##// END OF EJS Templates
add ownership to AsyncResult objects...
MinRK -
Show More
@@ -1,695 +1,703 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import sys
8 import sys
9 import time
9 import time
10 from datetime import datetime
10 from datetime import datetime
11
11
12 from zmq import MessageTracker
12 from zmq import MessageTracker
13
13
14 from IPython.core.display import clear_output, display, display_pretty
14 from IPython.core.display import clear_output, display, display_pretty
15 from IPython.external.decorator import decorator
15 from IPython.external.decorator import decorator
16 from IPython.parallel import error
16 from IPython.parallel import error
17 from IPython.utils.py3compat import string_types
17 from IPython.utils.py3compat import string_types
18
18
19 #-----------------------------------------------------------------------------
20 # Functions
21 #-----------------------------------------------------------------------------
22
19
23 def _raw_text(s):
20 def _raw_text(s):
24 display_pretty(s, raw=True)
21 display_pretty(s, raw=True)
25
22
26 #-----------------------------------------------------------------------------
27 # Classes
28 #-----------------------------------------------------------------------------
29
23
30 # global empty tracker that's always done:
24 # global empty tracker that's always done:
31 finished_tracker = MessageTracker()
25 finished_tracker = MessageTracker()
32
26
33 @decorator
27 @decorator
34 def check_ready(f, self, *args, **kwargs):
28 def check_ready(f, self, *args, **kwargs):
35 """Call spin() to sync state prior to calling the method."""
29 """Call spin() to sync state prior to calling the method."""
36 self.wait(0)
30 self.wait(0)
37 if not self._ready:
31 if not self._ready:
38 raise error.TimeoutError("result not ready")
32 raise error.TimeoutError("result not ready")
39 return f(self, *args, **kwargs)
33 return f(self, *args, **kwargs)
40
34
41 class AsyncResult(object):
35 class AsyncResult(object):
42 """Class for representing results of non-blocking calls.
36 """Class for representing results of non-blocking calls.
43
37
44 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
45 """
39 """
46
40
47 msg_ids = None
41 msg_ids = None
48 _targets = None
42 _targets = None
49 _tracker = None
43 _tracker = None
50 _single_result = False
44 _single_result = False
45 owner = False,
51
46
52 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 owner=False,
49 ):
53 if isinstance(msg_ids, string_types):
50 if isinstance(msg_ids, string_types):
54 # always a list
51 # always a list
55 msg_ids = [msg_ids]
52 msg_ids = [msg_ids]
56 self._single_result = True
53 self._single_result = True
57 else:
54 else:
58 self._single_result = False
55 self._single_result = False
59 if tracker is None:
56 if tracker is None:
60 # default to always done
57 # default to always done
61 tracker = finished_tracker
58 tracker = finished_tracker
62 self._client = client
59 self._client = client
63 self.msg_ids = msg_ids
60 self.msg_ids = msg_ids
64 self._fname=fname
61 self._fname=fname
65 self._targets = targets
62 self._targets = targets
66 self._tracker = tracker
63 self._tracker = tracker
64 self.owner = owner
67
65
68 self._ready = False
66 self._ready = False
69 self._outputs_ready = False
67 self._outputs_ready = False
70 self._success = None
68 self._success = None
71 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
72
70
73 def __repr__(self):
71 def __repr__(self):
74 if self._ready:
72 if self._ready:
75 return "<%s: finished>"%(self.__class__.__name__)
73 return "<%s: finished>"%(self.__class__.__name__)
76 else:
74 else:
77 return "<%s: %s>"%(self.__class__.__name__,self._fname)
75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
78
76
79
77
80 def _reconstruct_result(self, res):
78 def _reconstruct_result(self, res):
81 """Reconstruct our result from actual result list (always a list)
79 """Reconstruct our result from actual result list (always a list)
82
80
83 Override me in subclasses for turning a list of results
81 Override me in subclasses for turning a list of results
84 into the expected form.
82 into the expected form.
85 """
83 """
86 if self._single_result:
84 if self._single_result:
87 return res[0]
85 return res[0]
88 else:
86 else:
89 return res
87 return res
90
88
91 def get(self, timeout=-1):
89 def get(self, timeout=-1):
92 """Return the result when it arrives.
90 """Return the result when it arrives.
93
91
94 If `timeout` is not ``None`` and the result does not arrive within
92 If `timeout` is not ``None`` and the result does not arrive within
95 `timeout` seconds then ``TimeoutError`` is raised. If the
93 `timeout` seconds then ``TimeoutError`` is raised. If the
96 remote call raised an exception then that exception will be reraised
94 remote call raised an exception then that exception will be reraised
97 by get() inside a `RemoteError`.
95 by get() inside a `RemoteError`.
98 """
96 """
99 if not self.ready():
97 if not self.ready():
100 self.wait(timeout)
98 self.wait(timeout)
101
99
102 if self._ready:
100 if self._ready:
103 if self._success:
101 if self._success:
104 return self._result
102 return self._result
105 else:
103 else:
106 raise self._exception
104 raise self._exception
107 else:
105 else:
108 raise error.TimeoutError("Result not ready.")
106 raise error.TimeoutError("Result not ready.")
109
107
110 def _check_ready(self):
108 def _check_ready(self):
111 if not self.ready():
109 if not self.ready():
112 raise error.TimeoutError("Result not ready.")
110 raise error.TimeoutError("Result not ready.")
113
111
114 def ready(self):
112 def ready(self):
115 """Return whether the call has completed."""
113 """Return whether the call has completed."""
116 if not self._ready:
114 if not self._ready:
117 self.wait(0)
115 self.wait(0)
118 elif not self._outputs_ready:
116 elif not self._outputs_ready:
119 self._wait_for_outputs(0)
117 self._wait_for_outputs(0)
120
118
121 return self._ready
119 return self._ready
122
120
123 def wait(self, timeout=-1):
121 def wait(self, timeout=-1):
124 """Wait until the result is available or until `timeout` seconds pass.
122 """Wait until the result is available or until `timeout` seconds pass.
125
123
126 This method always returns None.
124 This method always returns None.
127 """
125 """
128 if self._ready:
126 if self._ready:
129 self._wait_for_outputs(timeout)
127 self._wait_for_outputs(timeout)
130 return
128 return
131 self._ready = self._client.wait(self.msg_ids, timeout)
129 self._ready = self._client.wait(self.msg_ids, timeout)
132 if self._ready:
130 if self._ready:
133 try:
131 try:
134 results = list(map(self._client.results.get, self.msg_ids))
132 results = list(map(self._client.results.get, self.msg_ids))
135 self._result = results
133 self._result = results
136 if self._single_result:
134 if self._single_result:
137 r = results[0]
135 r = results[0]
138 if isinstance(r, Exception):
136 if isinstance(r, Exception):
139 raise r
137 raise r
140 else:
138 else:
141 results = error.collect_exceptions(results, self._fname)
139 results = error.collect_exceptions(results, self._fname)
142 self._result = self._reconstruct_result(results)
140 self._result = self._reconstruct_result(results)
143 except Exception as e:
141 except Exception as e:
144 self._exception = e
142 self._exception = e
145 self._success = False
143 self._success = False
146 else:
144 else:
147 self._success = True
145 self._success = True
148 finally:
146 finally:
149 if timeout is None or timeout < 0:
147 if timeout is None or timeout < 0:
150 # cutoff infinite wait at 10s
148 # cutoff infinite wait at 10s
151 timeout = 10
149 timeout = 10
152 self._wait_for_outputs(timeout)
150 self._wait_for_outputs(timeout)
153
151
152 if self.owner:
153
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
156
157
154
158
155 def successful(self):
159 def successful(self):
156 """Return whether the call completed without raising an exception.
160 """Return whether the call completed without raising an exception.
157
161
158 Will raise ``AssertionError`` if the result is not ready.
162 Will raise ``AssertionError`` if the result is not ready.
159 """
163 """
160 assert self.ready()
164 assert self.ready()
161 return self._success
165 return self._success
162
166
163 #----------------------------------------------------------------
167 #----------------------------------------------------------------
164 # Extra methods not in mp.pool.AsyncResult
168 # Extra methods not in mp.pool.AsyncResult
165 #----------------------------------------------------------------
169 #----------------------------------------------------------------
166
170
167 def get_dict(self, timeout=-1):
171 def get_dict(self, timeout=-1):
168 """Get the results as a dict, keyed by engine_id.
172 """Get the results as a dict, keyed by engine_id.
169
173
170 timeout behavior is described in `get()`.
174 timeout behavior is described in `get()`.
171 """
175 """
172
176
173 results = self.get(timeout)
177 results = self.get(timeout)
174 if self._single_result:
178 if self._single_result:
175 results = [results]
179 results = [results]
176 engine_ids = [ md['engine_id'] for md in self._metadata ]
180 engine_ids = [ md['engine_id'] for md in self._metadata ]
177
181
178
182
179 rdict = {}
183 rdict = {}
180 for engine_id, result in zip(engine_ids, results):
184 for engine_id, result in zip(engine_ids, results):
181 if engine_id in rdict:
185 if engine_id in rdict:
182 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
183 engine_ids.count(engine_id), engine_id)
187 engine_ids.count(engine_id), engine_id)
184 )
188 )
185 else:
189 else:
186 rdict[engine_id] = result
190 rdict[engine_id] = result
187
191
188 return rdict
192 return rdict
189
193
190 @property
194 @property
191 def result(self):
195 def result(self):
192 """result property wrapper for `get(timeout=-1)`."""
196 """result property wrapper for `get(timeout=-1)`."""
193 return self.get()
197 return self.get()
194
198
195 # abbreviated alias:
199 # abbreviated alias:
196 r = result
200 r = result
197
201
198 @property
202 @property
199 def metadata(self):
203 def metadata(self):
200 """property for accessing execution metadata."""
204 """property for accessing execution metadata."""
201 if self._single_result:
205 if self._single_result:
202 return self._metadata[0]
206 return self._metadata[0]
203 else:
207 else:
204 return self._metadata
208 return self._metadata
205
209
206 @property
210 @property
207 def result_dict(self):
211 def result_dict(self):
208 """result property as a dict."""
212 """result property as a dict."""
209 return self.get_dict()
213 return self.get_dict()
210
214
211 def __dict__(self):
215 def __dict__(self):
212 return self.get_dict(0)
216 return self.get_dict(0)
213
217
214 def abort(self):
218 def abort(self):
215 """abort my tasks."""
219 """abort my tasks."""
216 assert not self.ready(), "Can't abort, I am already done!"
220 assert not self.ready(), "Can't abort, I am already done!"
217 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
218
222
219 @property
223 @property
220 def sent(self):
224 def sent(self):
221 """check whether my messages have been sent."""
225 """check whether my messages have been sent."""
222 return self._tracker.done
226 return self._tracker.done
223
227
224 def wait_for_send(self, timeout=-1):
228 def wait_for_send(self, timeout=-1):
225 """wait for pyzmq send to complete.
229 """wait for pyzmq send to complete.
226
230
227 This is necessary when sending arrays that you intend to edit in-place.
231 This is necessary when sending arrays that you intend to edit in-place.
228 `timeout` is in seconds, and will raise TimeoutError if it is reached
232 `timeout` is in seconds, and will raise TimeoutError if it is reached
229 before the send completes.
233 before the send completes.
230 """
234 """
231 return self._tracker.wait(timeout)
235 return self._tracker.wait(timeout)
232
236
233 #-------------------------------------
237 #-------------------------------------
234 # dict-access
238 # dict-access
235 #-------------------------------------
239 #-------------------------------------
236
240
237 def __getitem__(self, key):
241 def __getitem__(self, key):
238 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
239 """
243 """
240 if isinstance(key, int):
244 if isinstance(key, int):
241 self._check_ready()
245 self._check_ready()
242 return error.collect_exceptions([self._result[key]], self._fname)[0]
246 return error.collect_exceptions([self._result[key]], self._fname)[0]
243 elif isinstance(key, slice):
247 elif isinstance(key, slice):
244 self._check_ready()
248 self._check_ready()
245 return error.collect_exceptions(self._result[key], self._fname)
249 return error.collect_exceptions(self._result[key], self._fname)
246 elif isinstance(key, string_types):
250 elif isinstance(key, string_types):
247 # metadata proxy *does not* require that results are done
251 # metadata proxy *does not* require that results are done
248 self.wait(0)
252 self.wait(0)
249 values = [ md[key] for md in self._metadata ]
253 values = [ md[key] for md in self._metadata ]
250 if self._single_result:
254 if self._single_result:
251 return values[0]
255 return values[0]
252 else:
256 else:
253 return values
257 return values
254 else:
258 else:
255 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
256
260
257 def __getattr__(self, key):
261 def __getattr__(self, key):
258 """getattr maps to getitem for convenient attr access to metadata."""
262 """getattr maps to getitem for convenient attr access to metadata."""
259 try:
263 try:
260 return self.__getitem__(key)
264 return self.__getitem__(key)
261 except (error.TimeoutError, KeyError):
265 except (error.TimeoutError, KeyError):
262 raise AttributeError("%r object has no attribute %r"%(
266 raise AttributeError("%r object has no attribute %r"%(
263 self.__class__.__name__, key))
267 self.__class__.__name__, key))
264
268
265 # asynchronous iterator:
269 # asynchronous iterator:
266 def __iter__(self):
270 def __iter__(self):
267 if self._single_result:
271 if self._single_result:
268 raise TypeError("AsyncResults with a single result are not iterable.")
272 raise TypeError("AsyncResults with a single result are not iterable.")
269 try:
273 try:
270 rlist = self.get(0)
274 rlist = self.get(0)
271 except error.TimeoutError:
275 except error.TimeoutError:
272 # wait for each result individually
276 # wait for each result individually
273 for msg_id in self.msg_ids:
277 for msg_id in self.msg_ids:
274 ar = AsyncResult(self._client, msg_id, self._fname)
278 ar = AsyncResult(self._client, msg_id, self._fname)
275 yield ar.get()
279 yield ar.get()
276 else:
280 else:
277 # already done
281 # already done
278 for r in rlist:
282 for r in rlist:
279 yield r
283 yield r
280
284
281 def __len__(self):
285 def __len__(self):
282 return len(self.msg_ids)
286 return len(self.msg_ids)
283
287
284 #-------------------------------------
288 #-------------------------------------
285 # Sugar methods and attributes
289 # Sugar methods and attributes
286 #-------------------------------------
290 #-------------------------------------
287
291
288 def timedelta(self, start, end, start_key=min, end_key=max):
292 def timedelta(self, start, end, start_key=min, end_key=max):
289 """compute the difference between two sets of timestamps
293 """compute the difference between two sets of timestamps
290
294
291 The default behavior is to use the earliest of the first
295 The default behavior is to use the earliest of the first
292 and the latest of the second list, but this can be changed
296 and the latest of the second list, but this can be changed
293 by passing a different
297 by passing a different
294
298
295 Parameters
299 Parameters
296 ----------
300 ----------
297
301
298 start : one or more datetime objects (e.g. ar.submitted)
302 start : one or more datetime objects (e.g. ar.submitted)
299 end : one or more datetime objects (e.g. ar.received)
303 end : one or more datetime objects (e.g. ar.received)
300 start_key : callable
304 start_key : callable
301 Function to call on `start` to extract the relevant
305 Function to call on `start` to extract the relevant
302 entry [defalt: min]
306 entry [defalt: min]
303 end_key : callable
307 end_key : callable
304 Function to call on `end` to extract the relevant
308 Function to call on `end` to extract the relevant
305 entry [default: max]
309 entry [default: max]
306
310
307 Returns
311 Returns
308 -------
312 -------
309
313
310 dt : float
314 dt : float
311 The time elapsed (in seconds) between the two selected timestamps.
315 The time elapsed (in seconds) between the two selected timestamps.
312 """
316 """
313 if not isinstance(start, datetime):
317 if not isinstance(start, datetime):
314 # handle single_result AsyncResults, where ar.stamp is single object,
318 # handle single_result AsyncResults, where ar.stamp is single object,
315 # not a list
319 # not a list
316 start = start_key(start)
320 start = start_key(start)
317 if not isinstance(end, datetime):
321 if not isinstance(end, datetime):
318 # handle single_result AsyncResults, where ar.stamp is single object,
322 # handle single_result AsyncResults, where ar.stamp is single object,
319 # not a list
323 # not a list
320 end = end_key(end)
324 end = end_key(end)
321 return (end - start).total_seconds()
325 return (end - start).total_seconds()
322
326
323 @property
327 @property
324 def progress(self):
328 def progress(self):
325 """the number of tasks which have been completed at this point.
329 """the number of tasks which have been completed at this point.
326
330
327 Fractional progress would be given by 1.0 * ar.progress / len(ar)
331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
328 """
332 """
329 self.wait(0)
333 self.wait(0)
330 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
331
335
332 @property
336 @property
333 def elapsed(self):
337 def elapsed(self):
334 """elapsed time since initial submission"""
338 """elapsed time since initial submission"""
335 if self.ready():
339 if self.ready():
336 return self.wall_time
340 return self.wall_time
337
341
338 now = submitted = datetime.now()
342 now = submitted = datetime.now()
339 for msg_id in self.msg_ids:
343 for msg_id in self.msg_ids:
340 if msg_id in self._client.metadata:
344 if msg_id in self._client.metadata:
341 stamp = self._client.metadata[msg_id]['submitted']
345 stamp = self._client.metadata[msg_id]['submitted']
342 if stamp and stamp < submitted:
346 if stamp and stamp < submitted:
343 submitted = stamp
347 submitted = stamp
344 return (now-submitted).total_seconds()
348 return (now-submitted).total_seconds()
345
349
346 @property
350 @property
347 @check_ready
351 @check_ready
348 def serial_time(self):
352 def serial_time(self):
349 """serial computation time of a parallel calculation
353 """serial computation time of a parallel calculation
350
354
351 Computed as the sum of (completed-started) of each task
355 Computed as the sum of (completed-started) of each task
352 """
356 """
353 t = 0
357 t = 0
354 for md in self._metadata:
358 for md in self._metadata:
355 t += (md['completed'] - md['started']).total_seconds()
359 t += (md['completed'] - md['started']).total_seconds()
356 return t
360 return t
357
361
358 @property
362 @property
359 @check_ready
363 @check_ready
360 def wall_time(self):
364 def wall_time(self):
361 """actual computation time of a parallel calculation
365 """actual computation time of a parallel calculation
362
366
363 Computed as the time between the latest `received` stamp
367 Computed as the time between the latest `received` stamp
364 and the earliest `submitted`.
368 and the earliest `submitted`.
365
369
366 Only reliable if Client was spinning/waiting when the task finished, because
370 Only reliable if Client was spinning/waiting when the task finished, because
367 the `received` timestamp is created when a result is pulled off of the zmq queue,
371 the `received` timestamp is created when a result is pulled off of the zmq queue,
368 which happens as a result of `client.spin()`.
372 which happens as a result of `client.spin()`.
369
373
370 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
371
375
372 """
376 """
373 return self.timedelta(self.submitted, self.received)
377 return self.timedelta(self.submitted, self.received)
374
378
375 def wait_interactive(self, interval=1., timeout=-1):
379 def wait_interactive(self, interval=1., timeout=-1):
376 """interactive wait, printing progress at regular intervals"""
380 """interactive wait, printing progress at regular intervals"""
377 if timeout is None:
381 if timeout is None:
378 timeout = -1
382 timeout = -1
379 N = len(self)
383 N = len(self)
380 tic = time.time()
384 tic = time.time()
381 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
385 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
382 self.wait(interval)
386 self.wait(interval)
383 clear_output(wait=True)
387 clear_output(wait=True)
384 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
388 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
385 sys.stdout.flush()
389 sys.stdout.flush()
386 print()
390 print()
387 print("done")
391 print("done")
388
392
389 def _republish_displaypub(self, content, eid):
393 def _republish_displaypub(self, content, eid):
390 """republish individual displaypub content dicts"""
394 """republish individual displaypub content dicts"""
391 try:
395 try:
392 ip = get_ipython()
396 ip = get_ipython()
393 except NameError:
397 except NameError:
394 # displaypub is meaningless outside IPython
398 # displaypub is meaningless outside IPython
395 return
399 return
396 md = content['metadata'] or {}
400 md = content['metadata'] or {}
397 md['engine'] = eid
401 md['engine'] = eid
398 ip.display_pub.publish(content['source'], content['data'], md)
402 ip.display_pub.publish(content['source'], content['data'], md)
399
403
400 def _display_stream(self, text, prefix='', file=None):
404 def _display_stream(self, text, prefix='', file=None):
401 if not text:
405 if not text:
402 # nothing to display
406 # nothing to display
403 return
407 return
404 if file is None:
408 if file is None:
405 file = sys.stdout
409 file = sys.stdout
406 end = '' if text.endswith('\n') else '\n'
410 end = '' if text.endswith('\n') else '\n'
407
411
408 multiline = text.count('\n') > int(text.endswith('\n'))
412 multiline = text.count('\n') > int(text.endswith('\n'))
409 if prefix and multiline and not text.startswith('\n'):
413 if prefix and multiline and not text.startswith('\n'):
410 prefix = prefix + '\n'
414 prefix = prefix + '\n'
411 print("%s%s" % (prefix, text), file=file, end=end)
415 print("%s%s" % (prefix, text), file=file, end=end)
412
416
413
417
414 def _display_single_result(self):
418 def _display_single_result(self):
415 self._display_stream(self.stdout)
419 self._display_stream(self.stdout)
416 self._display_stream(self.stderr, file=sys.stderr)
420 self._display_stream(self.stderr, file=sys.stderr)
417
421
418 try:
422 try:
419 get_ipython()
423 get_ipython()
420 except NameError:
424 except NameError:
421 # displaypub is meaningless outside IPython
425 # displaypub is meaningless outside IPython
422 return
426 return
423
427
424 for output in self.outputs:
428 for output in self.outputs:
425 self._republish_displaypub(output, self.engine_id)
429 self._republish_displaypub(output, self.engine_id)
426
430
427 if self.execute_result is not None:
431 if self.execute_result is not None:
428 display(self.get())
432 display(self.get())
429
433
430 def _wait_for_outputs(self, timeout=-1):
434 def _wait_for_outputs(self, timeout=-1):
431 """wait for the 'status=idle' message that indicates we have all outputs
435 """wait for the 'status=idle' message that indicates we have all outputs
432 """
436 """
433 if self._outputs_ready or not self._success:
437 if self._outputs_ready or not self._success:
434 # don't wait on errors
438 # don't wait on errors
435 return
439 return
436
440
437 # cast None to -1 for infinite timeout
441 # cast None to -1 for infinite timeout
438 if timeout is None:
442 if timeout is None:
439 timeout = -1
443 timeout = -1
440
444
441 tic = time.time()
445 tic = time.time()
442 while True:
446 while True:
443 self._client._flush_iopub(self._client._iopub_socket)
447 self._client._flush_iopub(self._client._iopub_socket)
444 self._outputs_ready = all(md['outputs_ready']
448 self._outputs_ready = all(md['outputs_ready']
445 for md in self._metadata)
449 for md in self._metadata)
446 if self._outputs_ready or \
450 if self._outputs_ready or \
447 (timeout >= 0 and time.time() > tic + timeout):
451 (timeout >= 0 and time.time() > tic + timeout):
448 break
452 break
449 time.sleep(0.01)
453 time.sleep(0.01)
450
454
451 @check_ready
455 @check_ready
452 def display_outputs(self, groupby="type"):
456 def display_outputs(self, groupby="type"):
453 """republish the outputs of the computation
457 """republish the outputs of the computation
454
458
455 Parameters
459 Parameters
456 ----------
460 ----------
457
461
458 groupby : str [default: type]
462 groupby : str [default: type]
459 if 'type':
463 if 'type':
460 Group outputs by type (show all stdout, then all stderr, etc.):
464 Group outputs by type (show all stdout, then all stderr, etc.):
461
465
462 [stdout:1] foo
466 [stdout:1] foo
463 [stdout:2] foo
467 [stdout:2] foo
464 [stderr:1] bar
468 [stderr:1] bar
465 [stderr:2] bar
469 [stderr:2] bar
466 if 'engine':
470 if 'engine':
467 Display outputs for each engine before moving on to the next:
471 Display outputs for each engine before moving on to the next:
468
472
469 [stdout:1] foo
473 [stdout:1] foo
470 [stderr:1] bar
474 [stderr:1] bar
471 [stdout:2] foo
475 [stdout:2] foo
472 [stderr:2] bar
476 [stderr:2] bar
473
477
474 if 'order':
478 if 'order':
475 Like 'type', but further collate individual displaypub
479 Like 'type', but further collate individual displaypub
476 outputs. This is meant for cases of each command producing
480 outputs. This is meant for cases of each command producing
477 several plots, and you would like to see all of the first
481 several plots, and you would like to see all of the first
478 plots together, then all of the second plots, and so on.
482 plots together, then all of the second plots, and so on.
479 """
483 """
480 if self._single_result:
484 if self._single_result:
481 self._display_single_result()
485 self._display_single_result()
482 return
486 return
483
487
484 stdouts = self.stdout
488 stdouts = self.stdout
485 stderrs = self.stderr
489 stderrs = self.stderr
486 execute_results = self.execute_result
490 execute_results = self.execute_result
487 output_lists = self.outputs
491 output_lists = self.outputs
488 results = self.get()
492 results = self.get()
489
493
490 targets = self.engine_id
494 targets = self.engine_id
491
495
492 if groupby == "engine":
496 if groupby == "engine":
493 for eid,stdout,stderr,outputs,r,execute_result in zip(
497 for eid,stdout,stderr,outputs,r,execute_result in zip(
494 targets, stdouts, stderrs, output_lists, results, execute_results
498 targets, stdouts, stderrs, output_lists, results, execute_results
495 ):
499 ):
496 self._display_stream(stdout, '[stdout:%i] ' % eid)
500 self._display_stream(stdout, '[stdout:%i] ' % eid)
497 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
501 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
498
502
499 try:
503 try:
500 get_ipython()
504 get_ipython()
501 except NameError:
505 except NameError:
502 # displaypub is meaningless outside IPython
506 # displaypub is meaningless outside IPython
503 return
507 return
504
508
505 if outputs or execute_result is not None:
509 if outputs or execute_result is not None:
506 _raw_text('[output:%i]' % eid)
510 _raw_text('[output:%i]' % eid)
507
511
508 for output in outputs:
512 for output in outputs:
509 self._republish_displaypub(output, eid)
513 self._republish_displaypub(output, eid)
510
514
511 if execute_result is not None:
515 if execute_result is not None:
512 display(r)
516 display(r)
513
517
514 elif groupby in ('type', 'order'):
518 elif groupby in ('type', 'order'):
515 # republish stdout:
519 # republish stdout:
516 for eid,stdout in zip(targets, stdouts):
520 for eid,stdout in zip(targets, stdouts):
517 self._display_stream(stdout, '[stdout:%i] ' % eid)
521 self._display_stream(stdout, '[stdout:%i] ' % eid)
518
522
519 # republish stderr:
523 # republish stderr:
520 for eid,stderr in zip(targets, stderrs):
524 for eid,stderr in zip(targets, stderrs):
521 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
525 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
522
526
523 try:
527 try:
524 get_ipython()
528 get_ipython()
525 except NameError:
529 except NameError:
526 # displaypub is meaningless outside IPython
530 # displaypub is meaningless outside IPython
527 return
531 return
528
532
529 if groupby == 'order':
533 if groupby == 'order':
530 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
534 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
531 N = max(len(outputs) for outputs in output_lists)
535 N = max(len(outputs) for outputs in output_lists)
532 for i in range(N):
536 for i in range(N):
533 for eid in targets:
537 for eid in targets:
534 outputs = output_dict[eid]
538 outputs = output_dict[eid]
535 if len(outputs) >= N:
539 if len(outputs) >= N:
536 _raw_text('[output:%i]' % eid)
540 _raw_text('[output:%i]' % eid)
537 self._republish_displaypub(outputs[i], eid)
541 self._republish_displaypub(outputs[i], eid)
538 else:
542 else:
539 # republish displaypub output
543 # republish displaypub output
540 for eid,outputs in zip(targets, output_lists):
544 for eid,outputs in zip(targets, output_lists):
541 if outputs:
545 if outputs:
542 _raw_text('[output:%i]' % eid)
546 _raw_text('[output:%i]' % eid)
543 for output in outputs:
547 for output in outputs:
544 self._republish_displaypub(output, eid)
548 self._republish_displaypub(output, eid)
545
549
546 # finally, add execute_result:
550 # finally, add execute_result:
547 for eid,r,execute_result in zip(targets, results, execute_results):
551 for eid,r,execute_result in zip(targets, results, execute_results):
548 if execute_result is not None:
552 if execute_result is not None:
549 display(r)
553 display(r)
550
554
551 else:
555 else:
552 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
556 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
553
557
554
558
555
559
556
560
557 class AsyncMapResult(AsyncResult):
561 class AsyncMapResult(AsyncResult):
558 """Class for representing results of non-blocking gathers.
562 """Class for representing results of non-blocking gathers.
559
563
560 This will properly reconstruct the gather.
564 This will properly reconstruct the gather.
561
565
562 This class is iterable at any time, and will wait on results as they come.
566 This class is iterable at any time, and will wait on results as they come.
563
567
564 If ordered=False, then the first results to arrive will come first, otherwise
568 If ordered=False, then the first results to arrive will come first, otherwise
565 results will be yielded in the order they were submitted.
569 results will be yielded in the order they were submitted.
566
570
567 """
571 """
568
572
569 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
573 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
570 AsyncResult.__init__(self, client, msg_ids, fname=fname)
574 AsyncResult.__init__(self, client, msg_ids, fname=fname)
571 self._mapObject = mapObject
575 self._mapObject = mapObject
572 self._single_result = False
576 self._single_result = False
573 self.ordered = ordered
577 self.ordered = ordered
574
578
575 def _reconstruct_result(self, res):
579 def _reconstruct_result(self, res):
576 """Perform the gather on the actual results."""
580 """Perform the gather on the actual results."""
577 return self._mapObject.joinPartitions(res)
581 return self._mapObject.joinPartitions(res)
578
582
579 # asynchronous iterator:
583 # asynchronous iterator:
580 def __iter__(self):
584 def __iter__(self):
581 it = self._ordered_iter if self.ordered else self._unordered_iter
585 it = self._ordered_iter if self.ordered else self._unordered_iter
582 for r in it():
586 for r in it():
583 yield r
587 yield r
584
588
585 # asynchronous ordered iterator:
589 # asynchronous ordered iterator:
586 def _ordered_iter(self):
590 def _ordered_iter(self):
587 """iterator for results *as they arrive*, preserving submission order."""
591 """iterator for results *as they arrive*, preserving submission order."""
588 try:
592 try:
589 rlist = self.get(0)
593 rlist = self.get(0)
590 except error.TimeoutError:
594 except error.TimeoutError:
591 # wait for each result individually
595 # wait for each result individually
592 for msg_id in self.msg_ids:
596 for msg_id in self.msg_ids:
593 ar = AsyncResult(self._client, msg_id, self._fname)
597 ar = AsyncResult(self._client, msg_id, self._fname)
594 rlist = ar.get()
598 rlist = ar.get()
595 try:
599 try:
596 for r in rlist:
600 for r in rlist:
597 yield r
601 yield r
598 except TypeError:
602 except TypeError:
599 # flattened, not a list
603 # flattened, not a list
600 # this could get broken by flattened data that returns iterables
604 # this could get broken by flattened data that returns iterables
601 # but most calls to map do not expose the `flatten` argument
605 # but most calls to map do not expose the `flatten` argument
602 yield rlist
606 yield rlist
603 else:
607 else:
604 # already done
608 # already done
605 for r in rlist:
609 for r in rlist:
606 yield r
610 yield r
607
611
608 # asynchronous unordered iterator:
612 # asynchronous unordered iterator:
609 def _unordered_iter(self):
613 def _unordered_iter(self):
610 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
614 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
611 try:
615 try:
612 rlist = self.get(0)
616 rlist = self.get(0)
613 except error.TimeoutError:
617 except error.TimeoutError:
614 pending = set(self.msg_ids)
618 pending = set(self.msg_ids)
615 while pending:
619 while pending:
616 try:
620 try:
617 self._client.wait(pending, 1e-3)
621 self._client.wait(pending, 1e-3)
618 except error.TimeoutError:
622 except error.TimeoutError:
619 # ignore timeout error, because that only means
623 # ignore timeout error, because that only means
620 # *some* jobs are outstanding
624 # *some* jobs are outstanding
621 pass
625 pass
622 # update ready set with those no longer outstanding:
626 # update ready set with those no longer outstanding:
623 ready = pending.difference(self._client.outstanding)
627 ready = pending.difference(self._client.outstanding)
624 # update pending to exclude those that are finished
628 # update pending to exclude those that are finished
625 pending = pending.difference(ready)
629 pending = pending.difference(ready)
626 while ready:
630 while ready:
627 msg_id = ready.pop()
631 msg_id = ready.pop()
628 ar = AsyncResult(self._client, msg_id, self._fname)
632 ar = AsyncResult(self._client, msg_id, self._fname)
629 rlist = ar.get()
633 rlist = ar.get()
630 try:
634 try:
631 for r in rlist:
635 for r in rlist:
632 yield r
636 yield r
633 except TypeError:
637 except TypeError:
634 # flattened, not a list
638 # flattened, not a list
635 # this could get broken by flattened data that returns iterables
639 # this could get broken by flattened data that returns iterables
636 # but most calls to map do not expose the `flatten` argument
640 # but most calls to map do not expose the `flatten` argument
637 yield rlist
641 yield rlist
638 else:
642 else:
639 # already done
643 # already done
640 for r in rlist:
644 for r in rlist:
641 yield r
645 yield r
642
646
643
647
644 class AsyncHubResult(AsyncResult):
648 class AsyncHubResult(AsyncResult):
645 """Class to wrap pending results that must be requested from the Hub.
649 """Class to wrap pending results that must be requested from the Hub.
646
650
647 Note that waiting/polling on these objects requires polling the Hubover the network,
651 Note that waiting/polling on these objects requires polling the Hubover the network,
648 so use `AsyncHubResult.wait()` sparingly.
652 so use `AsyncHubResult.wait()` sparingly.
649 """
653 """
650
654
651 def _wait_for_outputs(self, timeout=-1):
655 def _wait_for_outputs(self, timeout=-1):
652 """no-op, because HubResults are never incomplete"""
656 """no-op, because HubResults are never incomplete"""
653 self._outputs_ready = True
657 self._outputs_ready = True
654
658
655 def wait(self, timeout=-1):
659 def wait(self, timeout=-1):
656 """wait for result to complete."""
660 """wait for result to complete."""
657 start = time.time()
661 start = time.time()
658 if self._ready:
662 if self._ready:
659 return
663 return
660 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
664 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
661 local_ready = self._client.wait(local_ids, timeout)
665 local_ready = self._client.wait(local_ids, timeout)
662 if local_ready:
666 if local_ready:
663 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
667 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
664 if not remote_ids:
668 if not remote_ids:
665 self._ready = True
669 self._ready = True
666 else:
670 else:
667 rdict = self._client.result_status(remote_ids, status_only=False)
671 rdict = self._client.result_status(remote_ids, status_only=False)
668 pending = rdict['pending']
672 pending = rdict['pending']
669 while pending and (timeout < 0 or time.time() < start+timeout):
673 while pending and (timeout < 0 or time.time() < start+timeout):
670 rdict = self._client.result_status(remote_ids, status_only=False)
674 rdict = self._client.result_status(remote_ids, status_only=False)
671 pending = rdict['pending']
675 pending = rdict['pending']
672 if pending:
676 if pending:
673 time.sleep(0.1)
677 time.sleep(0.1)
674 if not pending:
678 if not pending:
675 self._ready = True
679 self._ready = True
676 if self._ready:
680 if self._ready:
677 try:
681 try:
678 results = list(map(self._client.results.get, self.msg_ids))
682 results = list(map(self._client.results.get, self.msg_ids))
679 self._result = results
683 self._result = results
680 if self._single_result:
684 if self._single_result:
681 r = results[0]
685 r = results[0]
682 if isinstance(r, Exception):
686 if isinstance(r, Exception):
683 raise r
687 raise r
684 else:
688 else:
685 results = error.collect_exceptions(results, self._fname)
689 results = error.collect_exceptions(results, self._fname)
686 self._result = self._reconstruct_result(results)
690 self._result = self._reconstruct_result(results)
687 except Exception as e:
691 except Exception as e:
688 self._exception = e
692 self._exception = e
689 self._success = False
693 self._success = False
690 else:
694 else:
691 self._success = True
695 self._success = True
692 finally:
696 finally:
693 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
698 if self.owner:
699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
700 [self._client.results.pop(mid) for mid in self.msg_ids]
701
694
702
695 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1863 +1,1868 b''
1 """A semi-synchronous Client for IPython parallel"""
1 """A semi-synchronous Client for IPython parallel"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import os
8 import os
9 import json
9 import json
10 import sys
10 import sys
11 from threading import Thread, Event
11 from threading import Thread, Event
12 import time
12 import time
13 import warnings
13 import warnings
14 from datetime import datetime
14 from datetime import datetime
15 from getpass import getpass
15 from getpass import getpass
16 from pprint import pprint
16 from pprint import pprint
17
17
18 pjoin = os.path.join
18 pjoin = os.path.join
19
19
20 import zmq
20 import zmq
21
21
22 from IPython.config.configurable import MultipleInstanceError
22 from IPython.config.configurable import MultipleInstanceError
23 from IPython.core.application import BaseIPythonApplication
23 from IPython.core.application import BaseIPythonApplication
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25
25
26 from IPython.utils.capture import RichOutput
26 from IPython.utils.capture import RichOutput
27 from IPython.utils.coloransi import TermColors
27 from IPython.utils.coloransi import TermColors
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 from IPython.utils.localinterfaces import localhost, is_local_ip
29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 from IPython.utils.path import get_ipython_dir
30 from IPython.utils.path import get_ipython_dir
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 Dict, List, Bool, Set, Any)
33 Dict, List, Bool, Set, Any)
34 from IPython.external.decorator import decorator
34 from IPython.external.decorator import decorator
35 from IPython.external.ssh import tunnel
35 from IPython.external.ssh import tunnel
36
36
37 from IPython.parallel import Reference
37 from IPython.parallel import Reference
38 from IPython.parallel import error
38 from IPython.parallel import error
39 from IPython.parallel import util
39 from IPython.parallel import util
40
40
41 from IPython.kernel.zmq.session import Session, Message
41 from IPython.kernel.zmq.session import Session, Message
42 from IPython.kernel.zmq import serialize
42 from IPython.kernel.zmq import serialize
43
43
44 from .asyncresult import AsyncResult, AsyncHubResult
44 from .asyncresult import AsyncResult, AsyncHubResult
45 from .view import DirectView, LoadBalancedView
45 from .view import DirectView, LoadBalancedView
46
46
47 #--------------------------------------------------------------------------
47 #--------------------------------------------------------------------------
48 # Decorators for Client methods
48 # Decorators for Client methods
49 #--------------------------------------------------------------------------
49 #--------------------------------------------------------------------------
50
50
51 @decorator
51 @decorator
52 def spin_first(f, self, *args, **kwargs):
52 def spin_first(f, self, *args, **kwargs):
53 """Call spin() to sync state prior to calling the method."""
53 """Call spin() to sync state prior to calling the method."""
54 self.spin()
54 self.spin()
55 return f(self, *args, **kwargs)
55 return f(self, *args, **kwargs)
56
56
57
57
58 #--------------------------------------------------------------------------
58 #--------------------------------------------------------------------------
59 # Classes
59 # Classes
60 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
61
61
62
62
63 class ExecuteReply(RichOutput):
63 class ExecuteReply(RichOutput):
64 """wrapper for finished Execute results"""
64 """wrapper for finished Execute results"""
65 def __init__(self, msg_id, content, metadata):
65 def __init__(self, msg_id, content, metadata):
66 self.msg_id = msg_id
66 self.msg_id = msg_id
67 self._content = content
67 self._content = content
68 self.execution_count = content['execution_count']
68 self.execution_count = content['execution_count']
69 self.metadata = metadata
69 self.metadata = metadata
70
70
71 # RichOutput overrides
71 # RichOutput overrides
72
72
73 @property
73 @property
74 def source(self):
74 def source(self):
75 execute_result = self.metadata['execute_result']
75 execute_result = self.metadata['execute_result']
76 if execute_result:
76 if execute_result:
77 return execute_result.get('source', '')
77 return execute_result.get('source', '')
78
78
79 @property
79 @property
80 def data(self):
80 def data(self):
81 execute_result = self.metadata['execute_result']
81 execute_result = self.metadata['execute_result']
82 if execute_result:
82 if execute_result:
83 return execute_result.get('data', {})
83 return execute_result.get('data', {})
84
84
85 @property
85 @property
86 def _metadata(self):
86 def _metadata(self):
87 execute_result = self.metadata['execute_result']
87 execute_result = self.metadata['execute_result']
88 if execute_result:
88 if execute_result:
89 return execute_result.get('metadata', {})
89 return execute_result.get('metadata', {})
90
90
91 def display(self):
91 def display(self):
92 from IPython.display import publish_display_data
92 from IPython.display import publish_display_data
93 publish_display_data(self.data, self.metadata)
93 publish_display_data(self.data, self.metadata)
94
94
95 def _repr_mime_(self, mime):
95 def _repr_mime_(self, mime):
96 if mime not in self.data:
96 if mime not in self.data:
97 return
97 return
98 data = self.data[mime]
98 data = self.data[mime]
99 if mime in self._metadata:
99 if mime in self._metadata:
100 return data, self._metadata[mime]
100 return data, self._metadata[mime]
101 else:
101 else:
102 return data
102 return data
103
103
104 def __getitem__(self, key):
104 def __getitem__(self, key):
105 return self.metadata[key]
105 return self.metadata[key]
106
106
107 def __getattr__(self, key):
107 def __getattr__(self, key):
108 if key not in self.metadata:
108 if key not in self.metadata:
109 raise AttributeError(key)
109 raise AttributeError(key)
110 return self.metadata[key]
110 return self.metadata[key]
111
111
112 def __repr__(self):
112 def __repr__(self):
113 execute_result = self.metadata['execute_result'] or {'data':{}}
113 execute_result = self.metadata['execute_result'] or {'data':{}}
114 text_out = execute_result['data'].get('text/plain', '')
114 text_out = execute_result['data'].get('text/plain', '')
115 if len(text_out) > 32:
115 if len(text_out) > 32:
116 text_out = text_out[:29] + '...'
116 text_out = text_out[:29] + '...'
117
117
118 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
118 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
119
119
120 def _repr_pretty_(self, p, cycle):
120 def _repr_pretty_(self, p, cycle):
121 execute_result = self.metadata['execute_result'] or {'data':{}}
121 execute_result = self.metadata['execute_result'] or {'data':{}}
122 text_out = execute_result['data'].get('text/plain', '')
122 text_out = execute_result['data'].get('text/plain', '')
123
123
124 if not text_out:
124 if not text_out:
125 return
125 return
126
126
127 try:
127 try:
128 ip = get_ipython()
128 ip = get_ipython()
129 except NameError:
129 except NameError:
130 colors = "NoColor"
130 colors = "NoColor"
131 else:
131 else:
132 colors = ip.colors
132 colors = ip.colors
133
133
134 if colors == "NoColor":
134 if colors == "NoColor":
135 out = normal = ""
135 out = normal = ""
136 else:
136 else:
137 out = TermColors.Red
137 out = TermColors.Red
138 normal = TermColors.Normal
138 normal = TermColors.Normal
139
139
140 if '\n' in text_out and not text_out.startswith('\n'):
140 if '\n' in text_out and not text_out.startswith('\n'):
141 # add newline for multiline reprs
141 # add newline for multiline reprs
142 text_out = '\n' + text_out
142 text_out = '\n' + text_out
143
143
144 p.text(
144 p.text(
145 out + u'Out[%i:%i]: ' % (
145 out + u'Out[%i:%i]: ' % (
146 self.metadata['engine_id'], self.execution_count
146 self.metadata['engine_id'], self.execution_count
147 ) + normal + text_out
147 ) + normal + text_out
148 )
148 )
149
149
150
150
151 class Metadata(dict):
151 class Metadata(dict):
152 """Subclass of dict for initializing metadata values.
152 """Subclass of dict for initializing metadata values.
153
153
154 Attribute access works on keys.
154 Attribute access works on keys.
155
155
156 These objects have a strict set of keys - errors will raise if you try
156 These objects have a strict set of keys - errors will raise if you try
157 to add new keys.
157 to add new keys.
158 """
158 """
159 def __init__(self, *args, **kwargs):
159 def __init__(self, *args, **kwargs):
160 dict.__init__(self)
160 dict.__init__(self)
161 md = {'msg_id' : None,
161 md = {'msg_id' : None,
162 'submitted' : None,
162 'submitted' : None,
163 'started' : None,
163 'started' : None,
164 'completed' : None,
164 'completed' : None,
165 'received' : None,
165 'received' : None,
166 'engine_uuid' : None,
166 'engine_uuid' : None,
167 'engine_id' : None,
167 'engine_id' : None,
168 'follow' : None,
168 'follow' : None,
169 'after' : None,
169 'after' : None,
170 'status' : None,
170 'status' : None,
171
171
172 'execute_input' : None,
172 'execute_input' : None,
173 'execute_result' : None,
173 'execute_result' : None,
174 'error' : None,
174 'error' : None,
175 'stdout' : '',
175 'stdout' : '',
176 'stderr' : '',
176 'stderr' : '',
177 'outputs' : [],
177 'outputs' : [],
178 'data': {},
178 'data': {},
179 'outputs_ready' : False,
179 'outputs_ready' : False,
180 }
180 }
181 self.update(md)
181 self.update(md)
182 self.update(dict(*args, **kwargs))
182 self.update(dict(*args, **kwargs))
183
183
184 def __getattr__(self, key):
184 def __getattr__(self, key):
185 """getattr aliased to getitem"""
185 """getattr aliased to getitem"""
186 if key in self:
186 if key in self:
187 return self[key]
187 return self[key]
188 else:
188 else:
189 raise AttributeError(key)
189 raise AttributeError(key)
190
190
191 def __setattr__(self, key, value):
191 def __setattr__(self, key, value):
192 """setattr aliased to setitem, with strict"""
192 """setattr aliased to setitem, with strict"""
193 if key in self:
193 if key in self:
194 self[key] = value
194 self[key] = value
195 else:
195 else:
196 raise AttributeError(key)
196 raise AttributeError(key)
197
197
198 def __setitem__(self, key, value):
198 def __setitem__(self, key, value):
199 """strict static key enforcement"""
199 """strict static key enforcement"""
200 if key in self:
200 if key in self:
201 dict.__setitem__(self, key, value)
201 dict.__setitem__(self, key, value)
202 else:
202 else:
203 raise KeyError(key)
203 raise KeyError(key)
204
204
205
205
206 class Client(HasTraits):
206 class Client(HasTraits):
207 """A semi-synchronous client to the IPython ZMQ cluster
207 """A semi-synchronous client to the IPython ZMQ cluster
208
208
209 Parameters
209 Parameters
210 ----------
210 ----------
211
211
212 url_file : str/unicode; path to ipcontroller-client.json
212 url_file : str/unicode; path to ipcontroller-client.json
213 This JSON file should contain all the information needed to connect to a cluster,
213 This JSON file should contain all the information needed to connect to a cluster,
214 and is likely the only argument needed.
214 and is likely the only argument needed.
215 Connection information for the Hub's registration. If a json connector
215 Connection information for the Hub's registration. If a json connector
216 file is given, then likely no further configuration is necessary.
216 file is given, then likely no further configuration is necessary.
217 [Default: use profile]
217 [Default: use profile]
218 profile : bytes
218 profile : bytes
219 The name of the Cluster profile to be used to find connector information.
219 The name of the Cluster profile to be used to find connector information.
220 If run from an IPython application, the default profile will be the same
220 If run from an IPython application, the default profile will be the same
221 as the running application, otherwise it will be 'default'.
221 as the running application, otherwise it will be 'default'.
222 cluster_id : str
222 cluster_id : str
223 String id to added to runtime files, to prevent name collisions when using
223 String id to added to runtime files, to prevent name collisions when using
224 multiple clusters with a single profile simultaneously.
224 multiple clusters with a single profile simultaneously.
225 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
225 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
226 Since this is text inserted into filenames, typical recommendations apply:
226 Since this is text inserted into filenames, typical recommendations apply:
227 Simple character strings are ideal, and spaces are not recommended (but
227 Simple character strings are ideal, and spaces are not recommended (but
228 should generally work)
228 should generally work)
229 context : zmq.Context
229 context : zmq.Context
230 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 Pass an existing zmq.Context instance, otherwise the client will create its own.
231 debug : bool
231 debug : bool
232 flag for lots of message printing for debug purposes
232 flag for lots of message printing for debug purposes
233 timeout : int/float
233 timeout : int/float
234 time (in seconds) to wait for connection replies from the Hub
234 time (in seconds) to wait for connection replies from the Hub
235 [Default: 10]
235 [Default: 10]
236
236
237 #-------------- session related args ----------------
237 #-------------- session related args ----------------
238
238
239 config : Config object
239 config : Config object
240 If specified, this will be relayed to the Session for configuration
240 If specified, this will be relayed to the Session for configuration
241 username : str
241 username : str
242 set username for the session object
242 set username for the session object
243
243
244 #-------------- ssh related args ----------------
244 #-------------- ssh related args ----------------
245 # These are args for configuring the ssh tunnel to be used
245 # These are args for configuring the ssh tunnel to be used
246 # credentials are used to forward connections over ssh to the Controller
246 # credentials are used to forward connections over ssh to the Controller
247 # Note that the ip given in `addr` needs to be relative to sshserver
247 # Note that the ip given in `addr` needs to be relative to sshserver
248 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
248 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
249 # and set sshserver as the same machine the Controller is on. However,
249 # and set sshserver as the same machine the Controller is on. However,
250 # the only requirement is that sshserver is able to see the Controller
250 # the only requirement is that sshserver is able to see the Controller
251 # (i.e. is within the same trusted network).
251 # (i.e. is within the same trusted network).
252
252
253 sshserver : str
253 sshserver : str
254 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
254 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
255 If keyfile or password is specified, and this is not, it will default to
255 If keyfile or password is specified, and this is not, it will default to
256 the ip given in addr.
256 the ip given in addr.
257 sshkey : str; path to ssh private key file
257 sshkey : str; path to ssh private key file
258 This specifies a key to be used in ssh login, default None.
258 This specifies a key to be used in ssh login, default None.
259 Regular default ssh keys will be used without specifying this argument.
259 Regular default ssh keys will be used without specifying this argument.
260 password : str
260 password : str
261 Your ssh password to sshserver. Note that if this is left None,
261 Your ssh password to sshserver. Note that if this is left None,
262 you will be prompted for it if passwordless key based login is unavailable.
262 you will be prompted for it if passwordless key based login is unavailable.
263 paramiko : bool
263 paramiko : bool
264 flag for whether to use paramiko instead of shell ssh for tunneling.
264 flag for whether to use paramiko instead of shell ssh for tunneling.
265 [default: True on win32, False else]
265 [default: True on win32, False else]
266
266
267
267
268 Attributes
268 Attributes
269 ----------
269 ----------
270
270
271 ids : list of int engine IDs
271 ids : list of int engine IDs
272 requesting the ids attribute always synchronizes
272 requesting the ids attribute always synchronizes
273 the registration state. To request ids without synchronization,
273 the registration state. To request ids without synchronization,
274 use semi-private _ids attributes.
274 use semi-private _ids attributes.
275
275
276 history : list of msg_ids
276 history : list of msg_ids
277 a list of msg_ids, keeping track of all the execution
277 a list of msg_ids, keeping track of all the execution
278 messages you have submitted in order.
278 messages you have submitted in order.
279
279
280 outstanding : set of msg_ids
280 outstanding : set of msg_ids
281 a set of msg_ids that have been submitted, but whose
281 a set of msg_ids that have been submitted, but whose
282 results have not yet been received.
282 results have not yet been received.
283
283
284 results : dict
284 results : dict
285 a dict of all our results, keyed by msg_id
285 a dict of all our results, keyed by msg_id
286
286
287 block : bool
287 block : bool
288 determines default behavior when block not specified
288 determines default behavior when block not specified
289 in execution methods
289 in execution methods
290
290
291 Methods
291 Methods
292 -------
292 -------
293
293
294 spin
294 spin
295 flushes incoming results and registration state changes
295 flushes incoming results and registration state changes
296 control methods spin, and requesting `ids` also ensures up to date
296 control methods spin, and requesting `ids` also ensures up to date
297
297
298 wait
298 wait
299 wait on one or more msg_ids
299 wait on one or more msg_ids
300
300
301 execution methods
301 execution methods
302 apply
302 apply
303 legacy: execute, run
303 legacy: execute, run
304
304
305 data movement
305 data movement
306 push, pull, scatter, gather
306 push, pull, scatter, gather
307
307
308 query methods
308 query methods
309 queue_status, get_result, purge, result_status
309 queue_status, get_result, purge, result_status
310
310
311 control methods
311 control methods
312 abort, shutdown
312 abort, shutdown
313
313
314 """
314 """
315
315
316
316
317 block = Bool(False)
317 block = Bool(False)
318 outstanding = Set()
318 outstanding = Set()
319 results = Instance('collections.defaultdict', (dict,))
319 results = Instance('collections.defaultdict', (dict,))
320 metadata = Instance('collections.defaultdict', (Metadata,))
320 metadata = Instance('collections.defaultdict', (Metadata,))
321 history = List()
321 history = List()
322 debug = Bool(False)
322 debug = Bool(False)
323 _spin_thread = Any()
323 _spin_thread = Any()
324 _stop_spinning = Any()
324 _stop_spinning = Any()
325
325
326 profile=Unicode()
326 profile=Unicode()
327 def _profile_default(self):
327 def _profile_default(self):
328 if BaseIPythonApplication.initialized():
328 if BaseIPythonApplication.initialized():
329 # an IPython app *might* be running, try to get its profile
329 # an IPython app *might* be running, try to get its profile
330 try:
330 try:
331 return BaseIPythonApplication.instance().profile
331 return BaseIPythonApplication.instance().profile
332 except (AttributeError, MultipleInstanceError):
332 except (AttributeError, MultipleInstanceError):
333 # could be a *different* subclass of config.Application,
333 # could be a *different* subclass of config.Application,
334 # which would raise one of these two errors.
334 # which would raise one of these two errors.
335 return u'default'
335 return u'default'
336 else:
336 else:
337 return u'default'
337 return u'default'
338
338
339
339
340 _outstanding_dict = Instance('collections.defaultdict', (set,))
340 _outstanding_dict = Instance('collections.defaultdict', (set,))
341 _ids = List()
341 _ids = List()
342 _connected=Bool(False)
342 _connected=Bool(False)
343 _ssh=Bool(False)
343 _ssh=Bool(False)
344 _context = Instance('zmq.Context')
344 _context = Instance('zmq.Context')
345 _config = Dict()
345 _config = Dict()
346 _engines=Instance(util.ReverseDict, (), {})
346 _engines=Instance(util.ReverseDict, (), {})
347 # _hub_socket=Instance('zmq.Socket')
347 # _hub_socket=Instance('zmq.Socket')
348 _query_socket=Instance('zmq.Socket')
348 _query_socket=Instance('zmq.Socket')
349 _control_socket=Instance('zmq.Socket')
349 _control_socket=Instance('zmq.Socket')
350 _iopub_socket=Instance('zmq.Socket')
350 _iopub_socket=Instance('zmq.Socket')
351 _notification_socket=Instance('zmq.Socket')
351 _notification_socket=Instance('zmq.Socket')
352 _mux_socket=Instance('zmq.Socket')
352 _mux_socket=Instance('zmq.Socket')
353 _task_socket=Instance('zmq.Socket')
353 _task_socket=Instance('zmq.Socket')
354 _task_scheme=Unicode()
354 _task_scheme=Unicode()
355 _closed = False
355 _closed = False
356 _ignored_control_replies=Integer(0)
356 _ignored_control_replies=Integer(0)
357 _ignored_hub_replies=Integer(0)
357 _ignored_hub_replies=Integer(0)
358
358
359 def __new__(self, *args, **kw):
359 def __new__(self, *args, **kw):
360 # don't raise on positional args
360 # don't raise on positional args
361 return HasTraits.__new__(self, **kw)
361 return HasTraits.__new__(self, **kw)
362
362
363 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
363 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
364 context=None, debug=False,
364 context=None, debug=False,
365 sshserver=None, sshkey=None, password=None, paramiko=None,
365 sshserver=None, sshkey=None, password=None, paramiko=None,
366 timeout=10, cluster_id=None, **extra_args
366 timeout=10, cluster_id=None, **extra_args
367 ):
367 ):
368 if profile:
368 if profile:
369 super(Client, self).__init__(debug=debug, profile=profile)
369 super(Client, self).__init__(debug=debug, profile=profile)
370 else:
370 else:
371 super(Client, self).__init__(debug=debug)
371 super(Client, self).__init__(debug=debug)
372 if context is None:
372 if context is None:
373 context = zmq.Context.instance()
373 context = zmq.Context.instance()
374 self._context = context
374 self._context = context
375 self._stop_spinning = Event()
375 self._stop_spinning = Event()
376
376
377 if 'url_or_file' in extra_args:
377 if 'url_or_file' in extra_args:
378 url_file = extra_args['url_or_file']
378 url_file = extra_args['url_or_file']
379 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
379 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
380
380
381 if url_file and util.is_url(url_file):
381 if url_file and util.is_url(url_file):
382 raise ValueError("single urls cannot be specified, url-files must be used.")
382 raise ValueError("single urls cannot be specified, url-files must be used.")
383
383
384 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
384 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
385
385
386 if self._cd is not None:
386 if self._cd is not None:
387 if url_file is None:
387 if url_file is None:
388 if not cluster_id:
388 if not cluster_id:
389 client_json = 'ipcontroller-client.json'
389 client_json = 'ipcontroller-client.json'
390 else:
390 else:
391 client_json = 'ipcontroller-%s-client.json' % cluster_id
391 client_json = 'ipcontroller-%s-client.json' % cluster_id
392 url_file = pjoin(self._cd.security_dir, client_json)
392 url_file = pjoin(self._cd.security_dir, client_json)
393 if url_file is None:
393 if url_file is None:
394 raise ValueError(
394 raise ValueError(
395 "I can't find enough information to connect to a hub!"
395 "I can't find enough information to connect to a hub!"
396 " Please specify at least one of url_file or profile."
396 " Please specify at least one of url_file or profile."
397 )
397 )
398
398
399 with open(url_file) as f:
399 with open(url_file) as f:
400 cfg = json.load(f)
400 cfg = json.load(f)
401
401
402 self._task_scheme = cfg['task_scheme']
402 self._task_scheme = cfg['task_scheme']
403
403
404 # sync defaults from args, json:
404 # sync defaults from args, json:
405 if sshserver:
405 if sshserver:
406 cfg['ssh'] = sshserver
406 cfg['ssh'] = sshserver
407
407
408 location = cfg.setdefault('location', None)
408 location = cfg.setdefault('location', None)
409
409
410 proto,addr = cfg['interface'].split('://')
410 proto,addr = cfg['interface'].split('://')
411 addr = util.disambiguate_ip_address(addr, location)
411 addr = util.disambiguate_ip_address(addr, location)
412 cfg['interface'] = "%s://%s" % (proto, addr)
412 cfg['interface'] = "%s://%s" % (proto, addr)
413
413
414 # turn interface,port into full urls:
414 # turn interface,port into full urls:
415 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
415 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
416 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
416 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
417
417
418 url = cfg['registration']
418 url = cfg['registration']
419
419
420 if location is not None and addr == localhost():
420 if location is not None and addr == localhost():
421 # location specified, and connection is expected to be local
421 # location specified, and connection is expected to be local
422 if not is_local_ip(location) and not sshserver:
422 if not is_local_ip(location) and not sshserver:
423 # load ssh from JSON *only* if the controller is not on
423 # load ssh from JSON *only* if the controller is not on
424 # this machine
424 # this machine
425 sshserver=cfg['ssh']
425 sshserver=cfg['ssh']
426 if not is_local_ip(location) and not sshserver:
426 if not is_local_ip(location) and not sshserver:
427 # warn if no ssh specified, but SSH is probably needed
427 # warn if no ssh specified, but SSH is probably needed
428 # This is only a warning, because the most likely cause
428 # This is only a warning, because the most likely cause
429 # is a local Controller on a laptop whose IP is dynamic
429 # is a local Controller on a laptop whose IP is dynamic
430 warnings.warn("""
430 warnings.warn("""
431 Controller appears to be listening on localhost, but not on this machine.
431 Controller appears to be listening on localhost, but not on this machine.
432 If this is true, you should specify Client(...,sshserver='you@%s')
432 If this is true, you should specify Client(...,sshserver='you@%s')
433 or instruct your controller to listen on an external IP."""%location,
433 or instruct your controller to listen on an external IP."""%location,
434 RuntimeWarning)
434 RuntimeWarning)
435 elif not sshserver:
435 elif not sshserver:
436 # otherwise sync with cfg
436 # otherwise sync with cfg
437 sshserver = cfg['ssh']
437 sshserver = cfg['ssh']
438
438
439 self._config = cfg
439 self._config = cfg
440
440
441 self._ssh = bool(sshserver or sshkey or password)
441 self._ssh = bool(sshserver or sshkey or password)
442 if self._ssh and sshserver is None:
442 if self._ssh and sshserver is None:
443 # default to ssh via localhost
443 # default to ssh via localhost
444 sshserver = addr
444 sshserver = addr
445 if self._ssh and password is None:
445 if self._ssh and password is None:
446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
447 password=False
447 password=False
448 else:
448 else:
449 password = getpass("SSH Password for %s: "%sshserver)
449 password = getpass("SSH Password for %s: "%sshserver)
450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
451
451
452 # configure and construct the session
452 # configure and construct the session
453 try:
453 try:
454 extra_args['packer'] = cfg['pack']
454 extra_args['packer'] = cfg['pack']
455 extra_args['unpacker'] = cfg['unpack']
455 extra_args['unpacker'] = cfg['unpack']
456 extra_args['key'] = cast_bytes(cfg['key'])
456 extra_args['key'] = cast_bytes(cfg['key'])
457 extra_args['signature_scheme'] = cfg['signature_scheme']
457 extra_args['signature_scheme'] = cfg['signature_scheme']
458 except KeyError as exc:
458 except KeyError as exc:
459 msg = '\n'.join([
459 msg = '\n'.join([
460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
461 "If you are reusing connection files, remove them and start ipcontroller again."
461 "If you are reusing connection files, remove them and start ipcontroller again."
462 ])
462 ])
463 raise ValueError(msg.format(exc.message))
463 raise ValueError(msg.format(exc.message))
464
464
465 self.session = Session(**extra_args)
465 self.session = Session(**extra_args)
466
466
467 self._query_socket = self._context.socket(zmq.DEALER)
467 self._query_socket = self._context.socket(zmq.DEALER)
468
468
469 if self._ssh:
469 if self._ssh:
470 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
470 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
471 else:
471 else:
472 self._query_socket.connect(cfg['registration'])
472 self._query_socket.connect(cfg['registration'])
473
473
474 self.session.debug = self.debug
474 self.session.debug = self.debug
475
475
476 self._notification_handlers = {'registration_notification' : self._register_engine,
476 self._notification_handlers = {'registration_notification' : self._register_engine,
477 'unregistration_notification' : self._unregister_engine,
477 'unregistration_notification' : self._unregister_engine,
478 'shutdown_notification' : lambda msg: self.close(),
478 'shutdown_notification' : lambda msg: self.close(),
479 }
479 }
480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
481 'apply_reply' : self._handle_apply_reply}
481 'apply_reply' : self._handle_apply_reply}
482
482
483 try:
483 try:
484 self._connect(sshserver, ssh_kwargs, timeout)
484 self._connect(sshserver, ssh_kwargs, timeout)
485 except:
485 except:
486 self.close(linger=0)
486 self.close(linger=0)
487 raise
487 raise
488
488
489 # last step: setup magics, if we are in IPython:
489 # last step: setup magics, if we are in IPython:
490
490
491 try:
491 try:
492 ip = get_ipython()
492 ip = get_ipython()
493 except NameError:
493 except NameError:
494 return
494 return
495 else:
495 else:
496 if 'px' not in ip.magics_manager.magics:
496 if 'px' not in ip.magics_manager.magics:
497 # in IPython but we are the first Client.
497 # in IPython but we are the first Client.
498 # activate a default view for parallel magics.
498 # activate a default view for parallel magics.
499 self.activate()
499 self.activate()
500
500
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the context may be the shared zmq.Context.instance() (see __init__),
        # so only our sockets are closed here
        self.close()
504
504
505 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
505 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
506 if ipython_dir is None:
506 if ipython_dir is None:
507 ipython_dir = get_ipython_dir()
507 ipython_dir = get_ipython_dir()
508 if profile_dir is not None:
508 if profile_dir is not None:
509 try:
509 try:
510 self._cd = ProfileDir.find_profile_dir(profile_dir)
510 self._cd = ProfileDir.find_profile_dir(profile_dir)
511 return
511 return
512 except ProfileDirError:
512 except ProfileDirError:
513 pass
513 pass
514 elif profile is not None:
514 elif profile is not None:
515 try:
515 try:
516 self._cd = ProfileDir.find_profile_dir_by_name(
516 self._cd = ProfileDir.find_profile_dir_by_name(
517 ipython_dir, profile)
517 ipython_dir, profile)
518 return
518 return
519 except ProfileDirError:
519 except ProfileDirError:
520 pass
520 pass
521 self._cd = None
521 self._cd = None
522
522
523 def _update_engines(self, engines):
523 def _update_engines(self, engines):
524 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
524 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
525 for k,v in iteritems(engines):
525 for k,v in iteritems(engines):
526 eid = int(k)
526 eid = int(k)
527 if eid not in self._engines:
527 if eid not in self._engines:
528 self._ids.append(eid)
528 self._ids.append(eid)
529 self._engines[eid] = v
529 self._engines[eid] = v
530 self._ids = sorted(self._ids)
530 self._ids = sorted(self._ids)
531 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
531 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
532 self._task_scheme == 'pure' and self._task_socket:
532 self._task_scheme == 'pure' and self._task_socket:
533 self._stop_scheduling_tasks()
533 self._stop_scheduling_tasks()
534
534
535 def _stop_scheduling_tasks(self):
535 def _stop_scheduling_tasks(self):
536 """Stop scheduling tasks because an engine has been unregistered
536 """Stop scheduling tasks because an engine has been unregistered
537 from a pure ZMQ scheduler.
537 from a pure ZMQ scheduler.
538 """
538 """
539 self._task_socket.close()
539 self._task_socket.close()
540 self._task_socket = None
540 self._task_socket = None
541 msg = "An engine has been unregistered, and we are using pure " +\
541 msg = "An engine has been unregistered, and we are using pure " +\
542 "ZMQ task scheduling. Task farming will be disabled."
542 "ZMQ task scheduling. Task farming will be disabled."
543 if self.outstanding:
543 if self.outstanding:
544 msg += " If you were running tasks when this happened, " +\
544 msg += " If you were running tasks when this happened, " +\
545 "some `outstanding` msg_ids may never resolve."
545 "some `outstanding` msg_ids may never resolve."
546 warnings.warn(msg, RuntimeWarning)
546 warnings.warn(msg, RuntimeWarning)
547
547
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Accepts None (all engines), 'all', a single int (negative indexes
        from the end), a slice, or a collection of ints.

        Raises NoEnginesRegistered, TypeError, or IndexError on bad input.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            # NOTE(review): the self.ids property access appears to do the
            # flushing as a side effect — confirm against its definition
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, string_types):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            # negative ints index from the end of the current engine list
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            # wrap the single id so the sequence logic below applies
            targets = [targets]

        if isinstance(targets, slice):
            indices = list(range(len(self._ids))[targets])
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        # (engine uuids as bytes for routing, engine ids as ints)
        return [cast_bytes(self._engines[t]) for t in targets], list(targets)
580
580
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits up
        to `timeout` seconds for the reply, then connects the mux, task,
        notification, control, and iopub sockets (tunneled over SSH when
        self._ssh is set) and seeds the engine registry from the reply.

        Raises error.TimeoutError if the Hub does not answer in time, or a
        plain Exception if the Hub rejects the connection.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # all sockets share one SSH decision, made in __init__
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        # self._config['registration'] = dict(content)
        cfg = self._config
        if content['status'] == 'ok':
            self._mux_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._mux_socket, cfg['mux'])

            self._task_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._task_socket, cfg['task'])

            # subscribe to everything on the notification channel
            self._notification_socket = self._context.socket(zmq.SUB)
            self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._notification_socket, cfg['notification'])

            self._control_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._control_socket, cfg['control'])

            self._iopub_socket = self._context.socket(zmq.SUB)
            self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._iopub_socket, cfg['iopub'])

            # record the engines the Hub already knows about
            self._update_engines(dict(content['engines']))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
632
632
633 #--------------------------------------------------------------------------
633 #--------------------------------------------------------------------------
634 # handlers and callbacks for incoming messages
634 # handlers and callbacks for incoming messages
635 #--------------------------------------------------------------------------
635 #--------------------------------------------------------------------------
636
636
637 def _unwrap_exception(self, content):
637 def _unwrap_exception(self, content):
638 """unwrap exception, and remap engine_id to int."""
638 """unwrap exception, and remap engine_id to int."""
639 e = error.unwrap_exception(content)
639 e = error.unwrap_exception(content)
640 # print e.traceback
640 # print e.traceback
641 if e.engine_info:
641 if e.engine_info:
642 e_uuid = e.engine_info['engine_uuid']
642 e_uuid = e.engine_info['engine_uuid']
643 eid = self._engines[e_uuid]
643 eid = self._engines[e_uuid]
644 e.engine_info['engine_id'] = eid
644 e.engine_info['engine_id'] = eid
645 return e
645 return e
646
646
647 def _extract_metadata(self, msg):
647 def _extract_metadata(self, msg):
648 header = msg['header']
648 header = msg['header']
649 parent = msg['parent_header']
649 parent = msg['parent_header']
650 msg_meta = msg['metadata']
650 msg_meta = msg['metadata']
651 content = msg['content']
651 content = msg['content']
652 md = {'msg_id' : parent['msg_id'],
652 md = {'msg_id' : parent['msg_id'],
653 'received' : datetime.now(),
653 'received' : datetime.now(),
654 'engine_uuid' : msg_meta.get('engine', None),
654 'engine_uuid' : msg_meta.get('engine', None),
655 'follow' : msg_meta.get('follow', []),
655 'follow' : msg_meta.get('follow', []),
656 'after' : msg_meta.get('after', []),
656 'after' : msg_meta.get('after', []),
657 'status' : content['status'],
657 'status' : content['status'],
658 }
658 }
659
659
660 if md['engine_uuid'] is not None:
660 if md['engine_uuid'] is not None:
661 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
661 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
662
662
663 if 'date' in parent:
663 if 'date' in parent:
664 md['submitted'] = parent['date']
664 md['submitted'] = parent['date']
665 if 'started' in msg_meta:
665 if 'started' in msg_meta:
666 md['started'] = parse_date(msg_meta['started'])
666 md['started'] = parse_date(msg_meta['started'])
667 if 'date' in header:
667 if 'date' in header:
668 md['completed'] = header['date']
668 md['completed'] = header['date']
669 return md
669 return md
670
670
671 def _register_engine(self, msg):
671 def _register_engine(self, msg):
672 """Register a new engine, and update our connection info."""
672 """Register a new engine, and update our connection info."""
673 content = msg['content']
673 content = msg['content']
674 eid = content['id']
674 eid = content['id']
675 d = {eid : content['uuid']}
675 d = {eid : content['uuid']}
676 self._update_engines(d)
676 self._update_engines(d)
677
677
678 def _unregister_engine(self, msg):
678 def _unregister_engine(self, msg):
679 """Unregister an engine that has died."""
679 """Unregister an engine that has died."""
680 content = msg['content']
680 content = msg['content']
681 eid = int(content['id'])
681 eid = int(content['id'])
682 if eid in self._ids:
682 if eid in self._ids:
683 self._ids.remove(eid)
683 self._ids.remove(eid)
684 uuid = self._engines.pop(eid)
684 uuid = self._engines.pop(eid)
685
685
686 self._handle_stranded_msgs(eid, uuid)
686 self._handle_stranded_msgs(eid, uuid)
687
687
688 if self._task_socket and self._task_scheme == 'pure':
688 if self._task_socket and self._task_scheme == 'pure':
689 self._stop_scheduling_tasks()
689 self._stop_scheduling_tasks()
690
690
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.

        Each still-unresolved msg_id is resolved to an EngineError via a
        synthesized apply_reply message.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already
                # have this result; nothing is stranded for this msg_id
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                # raise-then-catch so wrap_exception captures a real traceback
                content = error.wrap_exception()
            # build a fake message:
            msg = self.session.msg('apply_reply', content=content)
            msg['parent_header']['msg_id'] = msg_id
            msg['metadata']['engine'] = uuid
            self._handle_apply_reply(msg)
714
714
715 def _handle_execute_reply(self, msg):
715 def _handle_execute_reply(self, msg):
716 """Save the reply to an execute_request into our results.
716 """Save the reply to an execute_request into our results.
717
717
718 execute messages are never actually used. apply is used instead.
718 execute messages are never actually used. apply is used instead.
719 """
719 """
720
720
721 parent = msg['parent_header']
721 parent = msg['parent_header']
722 msg_id = parent['msg_id']
722 msg_id = parent['msg_id']
723 if msg_id not in self.outstanding:
723 if msg_id not in self.outstanding:
724 if msg_id in self.history:
724 if msg_id in self.history:
725 print("got stale result: %s"%msg_id)
725 print("got stale result: %s"%msg_id)
726 else:
726 else:
727 print("got unknown result: %s"%msg_id)
727 print("got unknown result: %s"%msg_id)
728 else:
728 else:
729 self.outstanding.remove(msg_id)
729 self.outstanding.remove(msg_id)
730
730
731 content = msg['content']
731 content = msg['content']
732 header = msg['header']
732 header = msg['header']
733
733
734 # construct metadata:
734 # construct metadata:
735 md = self.metadata[msg_id]
735 md = self.metadata[msg_id]
736 md.update(self._extract_metadata(msg))
736 md.update(self._extract_metadata(msg))
737 # is this redundant?
737 # is this redundant?
738 self.metadata[msg_id] = md
738 self.metadata[msg_id] = md
739
739
740 e_outstanding = self._outstanding_dict[md['engine_uuid']]
740 e_outstanding = self._outstanding_dict[md['engine_uuid']]
741 if msg_id in e_outstanding:
741 if msg_id in e_outstanding:
742 e_outstanding.remove(msg_id)
742 e_outstanding.remove(msg_id)
743
743
744 # construct result:
744 # construct result:
745 if content['status'] == 'ok':
745 if content['status'] == 'ok':
746 self.results[msg_id] = ExecuteReply(msg_id, content, md)
746 self.results[msg_id] = ExecuteReply(msg_id, content, md)
747 elif content['status'] == 'aborted':
747 elif content['status'] == 'aborted':
748 self.results[msg_id] = error.TaskAborted(msg_id)
748 self.results[msg_id] = error.TaskAborted(msg_id)
749 elif content['status'] == 'resubmitted':
749 elif content['status'] == 'resubmitted':
750 # TODO: handle resubmission
750 # TODO: handle resubmission
751 pass
751 pass
752 else:
752 else:
753 self.results[msg_id] = self._unwrap_exception(content)
753 self.results[msg_id] = self._unwrap_exception(content)
754
754
755 def _handle_apply_reply(self, msg):
755 def _handle_apply_reply(self, msg):
756 """Save the reply to an apply_request into our results."""
756 """Save the reply to an apply_request into our results."""
757 parent = msg['parent_header']
757 parent = msg['parent_header']
758 msg_id = parent['msg_id']
758 msg_id = parent['msg_id']
759 if msg_id not in self.outstanding:
759 if msg_id not in self.outstanding:
760 if msg_id in self.history:
760 if msg_id in self.history:
761 print("got stale result: %s"%msg_id)
761 print("got stale result: %s"%msg_id)
762 print(self.results[msg_id])
762 print(self.results[msg_id])
763 print(msg)
763 print(msg)
764 else:
764 else:
765 print("got unknown result: %s"%msg_id)
765 print("got unknown result: %s"%msg_id)
766 else:
766 else:
767 self.outstanding.remove(msg_id)
767 self.outstanding.remove(msg_id)
768 content = msg['content']
768 content = msg['content']
769 header = msg['header']
769 header = msg['header']
770
770
771 # construct metadata:
771 # construct metadata:
772 md = self.metadata[msg_id]
772 md = self.metadata[msg_id]
773 md.update(self._extract_metadata(msg))
773 md.update(self._extract_metadata(msg))
774 # is this redundant?
774 # is this redundant?
775 self.metadata[msg_id] = md
775 self.metadata[msg_id] = md
776
776
777 e_outstanding = self._outstanding_dict[md['engine_uuid']]
777 e_outstanding = self._outstanding_dict[md['engine_uuid']]
778 if msg_id in e_outstanding:
778 if msg_id in e_outstanding:
779 e_outstanding.remove(msg_id)
779 e_outstanding.remove(msg_id)
780
780
781 # construct result:
781 # construct result:
782 if content['status'] == 'ok':
782 if content['status'] == 'ok':
783 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
783 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
784 elif content['status'] == 'aborted':
784 elif content['status'] == 'aborted':
785 self.results[msg_id] = error.TaskAborted(msg_id)
785 self.results[msg_id] = error.TaskAborted(msg_id)
786 elif content['status'] == 'resubmitted':
786 elif content['status'] == 'resubmitted':
787 # TODO: handle resubmission
787 # TODO: handle resubmission
788 pass
788 pass
789 else:
789 else:
790 self.results[msg_id] = self._unwrap_exception(content)
790 self.results[msg_id] = self._unwrap_exception(content)
791
791
792 def _flush_notifications(self):
792 def _flush_notifications(self):
793 """Flush notifications of engine registrations waiting
793 """Flush notifications of engine registrations waiting
794 in ZMQ queue."""
794 in ZMQ queue."""
795 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
795 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
796 while msg is not None:
796 while msg is not None:
797 if self.debug:
797 if self.debug:
798 pprint(msg)
798 pprint(msg)
799 msg_type = msg['header']['msg_type']
799 msg_type = msg['header']['msg_type']
800 handler = self._notification_handlers.get(msg_type, None)
800 handler = self._notification_handlers.get(msg_type, None)
801 if handler is None:
801 if handler is None:
802 raise Exception("Unhandled message type: %s" % msg_type)
802 raise Exception("Unhandled message type: %s" % msg_type)
803 else:
803 else:
804 handler(msg)
804 handler(msg)
805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
806
806
807 def _flush_results(self, sock):
807 def _flush_results(self, sock):
808 """Flush task or queue results waiting in ZMQ queue."""
808 """Flush task or queue results waiting in ZMQ queue."""
809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
810 while msg is not None:
810 while msg is not None:
811 if self.debug:
811 if self.debug:
812 pprint(msg)
812 pprint(msg)
813 msg_type = msg['header']['msg_type']
813 msg_type = msg['header']['msg_type']
814 handler = self._queue_handlers.get(msg_type, None)
814 handler = self._queue_handlers.get(msg_type, None)
815 if handler is None:
815 if handler is None:
816 raise Exception("Unhandled message type: %s" % msg_type)
816 raise Exception("Unhandled message type: %s" % msg_type)
817 else:
817 else:
818 handler(msg)
818 handler(msg)
819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
820
820
821 def _flush_control(self, sock):
821 def _flush_control(self, sock):
822 """Flush replies from the control channel waiting
822 """Flush replies from the control channel waiting
823 in the ZMQ queue.
823 in the ZMQ queue.
824
824
825 Currently: ignore them."""
825 Currently: ignore them."""
826 if self._ignored_control_replies <= 0:
826 if self._ignored_control_replies <= 0:
827 return
827 return
828 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
828 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
829 while msg is not None:
829 while msg is not None:
830 self._ignored_control_replies -= 1
830 self._ignored_control_replies -= 1
831 if self.debug:
831 if self.debug:
832 pprint(msg)
832 pprint(msg)
833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834
834
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv once for each reply we previously chose to ignore
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
840
840
def _flush_ignored_hub_replies(self):
    """Drain and discard any replies waiting on the hub query socket."""
    while True:
        ident, reply = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        if reply is None:
            break
845
845
def _flush_iopub(self, sock):
    """Flush replies from the iopub channel waiting
    in the ZMQ queue.
    """
    while True:
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            break
        if self.debug:
            pprint(msg)
        parent = msg['parent_header']
        # IOPub messages with no parent come from output produced before
        # the first execution (e.g. import-time prints/warnings); skip them.
        if not parent:
            continue
        msg_id = parent['msg_id']
        content = msg['content']
        msg_type = msg['header']['msg_type']

        # metadata entry for the request that produced this output
        md = self.metadata[msg_id]

        if msg_type == 'stream':
            stream_name = content['name']
            md[stream_name] = (md[stream_name] or '') + content['data']
        elif msg_type == 'error':
            md.update({'error': self._unwrap_exception(content)})
        elif msg_type == 'execute_input':
            md.update({'execute_input': content['code']})
        elif msg_type == 'display_data':
            md['outputs'].append(content)
        elif msg_type == 'execute_result':
            md['execute_result'] = content
        elif msg_type == 'data_message':
            data, remainder = serialize.unserialize_object(msg['buffers'])
            md['data'].update(data)
        elif msg_type == 'status':
            # the idle status message arrives after all outputs
            if content['execution_state'] == 'idle':
                md['outputs_ready'] = True
        else:
            # ignore any other message types
            pass

        # write back (possibly redundant if metadata[msg_id] returns a live ref)
        self.metadata[msg_id] = md
895
895
896 #--------------------------------------------------------------------------
896 #--------------------------------------------------------------------------
897 # len, getitem
897 # len, getitem
898 #--------------------------------------------------------------------------
898 #--------------------------------------------------------------------------
899
899
def __len__(self):
    """Return the number of currently registered engines."""
    engine_ids = self.ids
    return len(engine_ids)
903
903
def __getitem__(self, key):
    """index access returns DirectView multiplexer objects

    Must be int, slice, or list/tuple/xrange of ints"""
    # guard clause: reject anything that isn't an int/slice/iterable of ints
    if isinstance(key, (int, slice, tuple, list, xrange)):
        return self.direct_view(key)
    raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
912
912
def __iter__(self):
    """Since we define getitem, Client is iterable

    but unless we also define __iter__, it won't work correctly unless engine IDs
    start at zero and are continuous.
    """
    # yield one single-engine DirectView per registered engine;
    # self.ids is read lazily, when iteration actually begins
    for engine_id in self.ids:
        yield self.direct_view(engine_id)
921
921
922 #--------------------------------------------------------------------------
922 #--------------------------------------------------------------------------
923 # Begin public methods
923 # Begin public methods
924 #--------------------------------------------------------------------------
924 #--------------------------------------------------------------------------
925
925
@property
def ids(self):
    """Always up-to-date ids property."""
    # process any pending (un)registration notifications first,
    # so the returned list reflects the current engine set
    self._flush_notifications()
    # return a copy so callers cannot mutate our internal list
    return list(self._ids)
932
932
def activate(self, targets='all', suffix=''):
    """Create a DirectView and register it with IPython magics

    Defines the magics `%px, %autopx, %pxresult, %%px`

    Parameters
    ----------

    targets: int, list of ints, or 'all'
        The engines on which the view's magics will run
    suffix: str [default: '']
        The suffix, if any, for the magics.  This allows you to have
        multiple views associated with parallel magics at the same time.

        e.g. ``rc.activate(targets=0, suffix='0')`` will give you
        the magics ``%px0``, ``%pxresult0``, etc. for running magics just
        on engine 0.
    """
    dv = self.direct_view(targets)
    # magics should block by default
    dv.block = True
    dv.activate(suffix)
    return dv
955
955
def close(self, linger=None):
    """Close my zmq Sockets

    If `linger`, set the zmq LINGER socket option,
    which allows discarding of messages.
    """
    # idempotent: closing twice is a no-op
    if self._closed:
        return
    self.stop_spin_thread()
    # every socket is held in a trait whose name ends with "socket"
    socket_names = [name for name in self.trait_names() if name.endswith("socket")]
    for attr in socket_names:
        sock = getattr(self, attr)
        if sock is None or sock.closed:
            continue
        if linger is None:
            sock.close()
        else:
            sock.close(linger=linger)
    self._closed = True
974
974
def _spin_every(self, interval=1):
    """target func for use in spin_thread"""
    # loop until stop_spin_thread() sets the stop event
    while not self._stop_spinning.is_set():
        time.sleep(interval)
        self.spin()
982
982
def spin_thread(self, interval=1):
    """call Client.spin() in a background thread on some regular interval

    This helps ensure that messages don't pile up too much in the zmq queue
    while you are working on other things, or just leaving an idle terminal.

    It also helps limit potential padding of the `received` timestamp
    on AsyncResult objects, used for timings.

    Parameters
    ----------

    interval : float, optional
        The interval on which to spin the client in the background thread
        (simply passed to time.sleep).

    Notes
    -----

    For precision timing, you may want to use this method to put a bound
    on the jitter (in seconds) in `received` timestamps used
    in AsyncResult.wall_time.

    """
    # replace any existing spin thread
    if self._spin_thread is not None:
        self.stop_spin_thread()
    self._stop_spinning.clear()
    worker = Thread(target=self._spin_every, args=(interval,))
    # daemonize so the thread never blocks interpreter exit
    worker.daemon = True
    self._spin_thread = worker
    worker.start()
1013
1013
def stop_spin_thread(self):
    """stop background spin_thread, if any"""
    worker = self._spin_thread
    if worker is None:
        return
    # signal the loop in _spin_every to exit, then wait for it
    self._stop_spinning.set()
    worker.join()
    self._spin_thread = None
1020
1020
def spin(self):
    """Flush any registration notifications and execution results
    waiting in the ZMQ queue.
    """
    # (socket, flusher) pairs; a socket may be None/unset, in which
    # case its channel is skipped
    channels = (
        (self._notification_socket, lambda s: self._flush_notifications()),
        (self._iopub_socket, self._flush_iopub),
        (self._mux_socket, self._flush_results),
        (self._task_socket, self._flush_results),
        (self._control_socket, self._flush_control),
        (self._query_socket, lambda s: self._flush_ignored_hub_replies()),
    )
    for sock, flush in channels:
        if sock:
            flush(sock)
1037
1037
def wait(self, jobs=None, timeout=-1):
    """waits on one or more `jobs`, for up to `timeout` seconds.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
        ints are indices to self.history
        strs are msg_ids
        default: wait on all outstanding messages
    timeout : float
        a time in seconds, after which to give up.
        default is -1, which means no timeout

    Returns
    -------

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    start = time.time()
    if jobs is None:
        pending = self.outstanding
    else:
        # normalize a single job into a one-element list
        if isinstance(jobs, string_types + (int, AsyncResult)):
            jobs = [jobs]
        pending = set()
        for job in jobs:
            if isinstance(job, int):
                # history index -> msg_id
                job = self.history[job]
            elif isinstance(job, AsyncResult):
                pending.update(job.msg_ids)
                continue
            pending.add(job)
    # fast path: nothing requested is still outstanding
    if not pending.intersection(self.outstanding):
        return True
    self.spin()
    # poll until done or timed out
    while pending.intersection(self.outstanding):
        if timeout >= 0 and (time.time() - start) > timeout:
            break
        time.sleep(1e-3)
        self.spin()
    return not pending.intersection(self.outstanding)
1082
1082
1083 #--------------------------------------------------------------------------
1083 #--------------------------------------------------------------------------
1084 # Control methods
1084 # Control methods
1085 #--------------------------------------------------------------------------
1085 #--------------------------------------------------------------------------
1086
1086
@spin_first
def clear(self, targets=None, block=None):
    """Clear the namespace in target(s)."""
    block = self.block if block is None else block
    engine_idents = self._build_targets(targets)[0]
    for eid in engine_idents:
        self.session.send(self._control_socket, 'clear_request', content={}, ident=eid)
    if not block:
        # non-blocking: remember how many replies to discard later
        self._ignored_control_replies += len(engine_idents)
        return
    self._flush_ignored_control()
    failure = False
    for _ in engine_idents:
        idents, reply = self.session.recv(self._control_socket, 0)
        if self.debug:
            pprint(reply)
        if reply['content']['status'] != 'ok':
            # keep only the last failure; raise it after collecting all replies
            failure = self._unwrap_exception(reply['content'])
    if failure:
        raise failure
1107
1107
1108
1108
@spin_first
def abort(self, jobs=None, targets=None, block=None):
    """Abort specific jobs from the execution queues of target(s).

    This is a mechanism to prevent jobs that have already been submitted
    from executing.

    Parameters
    ----------

    jobs : msg_id, list of msg_ids, or AsyncResult
        The jobs to be aborted

        If unspecified/None: abort all outstanding jobs.

    """
    block = self.block if block is None else block
    jobs = list(self.outstanding) if jobs is None else jobs
    targets = self._build_targets(targets)[0]

    # normalize and validate the job list
    if isinstance(jobs, string_types + (AsyncResult,)):
        jobs = [jobs]
    bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
    if bad_ids:
        raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
    msg_ids = []
    for job in jobs:
        if isinstance(job, AsyncResult):
            msg_ids.extend(job.msg_ids)
        else:
            msg_ids.append(job)
    content = dict(msg_ids=msg_ids)
    for eid in targets:
        self.session.send(self._control_socket, 'abort_request',
                content=content, ident=eid)
    if not block:
        # non-blocking: discard the replies later
        self._ignored_control_replies += len(targets)
        return
    self._flush_ignored_control()
    failure = False
    for _ in targets:
        idents, reply = self.session.recv(self._control_socket, 0)
        if self.debug:
            pprint(reply)
        if reply['content']['status'] != 'ok':
            # keep only the last failure; raise after all replies are in
            failure = self._unwrap_exception(reply['content'])
    if failure:
        raise failure
1157
1157
@spin_first
def shutdown(self, targets='all', restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub.

    Parameters
    ----------

    targets: list of ints or 'all' [default: all]
        Which engines to shutdown.
    hub: bool [default: False]
        Whether to include the Hub.  hub=True implies targets='all'.
    block: bool [default: self.block]
        Whether to wait for clean shutdown replies or not.
    restart: bool [default: False]
        NOT IMPLEMENTED
        whether to restart engines after shutting them down.
    """
    from IPython.parallel.error import NoEnginesRegistered
    if restart:
        raise NotImplementedError("Engine restart is not yet implemented")

    block = self.block if block is None else block
    # shutting down the hub implies shutting down all engines
    if hub:
        targets = 'all'
    try:
        targets = self._build_targets(targets)[0]
    except NoEnginesRegistered:
        targets = []
    for eid in targets:
        self.session.send(self._control_socket, 'shutdown_request',
                content={'restart':restart},ident=eid)
    error = False
    if block or hub:
        self._flush_ignored_control()
        for _ in targets:
            idents, reply = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(reply)
            if reply['content']['status'] != 'ok':
                error = self._unwrap_exception(reply['content'])
    else:
        self._ignored_control_replies += len(targets)

    if hub:
        # give the engine shutdown requests a moment before stopping the hub
        time.sleep(0.25)
        self.session.send(self._query_socket, 'shutdown_request')
        idents, reply = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(reply)
        if reply['content']['status'] != 'ok':
            error = self._unwrap_exception(reply['content'])

    if error:
        raise error
1212
1212
1213 #--------------------------------------------------------------------------
1213 #--------------------------------------------------------------------------
1214 # Execution related methods
1214 # Execution related methods
1215 #--------------------------------------------------------------------------
1215 #--------------------------------------------------------------------------
1216
1216
def _maybe_raise(self, result):
    """wrapper for maybe raising an exception if apply failed."""
    # a RemoteError result means the remote call failed: re-raise locally
    if isinstance(result, error.RemoteError):
        raise result
    return result
1223
1223
def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
                        ident=None):
    """construct and send an apply message via a socket.

    This is the principal method with which all engine execution is performed by views.
    """

    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # defaults:
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs
    metadata = {} if metadata is None else metadata

    # validate arguments
    if not callable(f) and not isinstance(f, Reference):
        raise TypeError("f must be callable, not %s"%type(f))
    if not isinstance(args, (tuple, list)):
        raise TypeError("args must be tuple or list, not %s"%type(args))
    if not isinstance(kwargs, dict):
        raise TypeError("kwargs must be dict, not %s"%type(kwargs))
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be dict, not %s"%type(metadata))

    # serialize f/args/kwargs into message buffers
    bufs = serialize.pack_apply_message(f, args, kwargs,
        buffer_threshold=self.session.buffer_threshold,
        item_threshold=self.session.item_threshold,
    )

    msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                        metadata=metadata, track=track)

    # book-keeping for the new outstanding request
    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine
        if isinstance(ident, list):
            ident = ident[-1]
        if ident in self._engines.values():
            # save per-engine as well, in case of engine death
            self._outstanding_dict[ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1270
1270
def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
    """construct and send an execute request via a socket.

    """

    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # defaults:
    metadata = {} if metadata is None else metadata

    # validate arguments
    if not isinstance(code, string_types):
        raise TypeError("code must be text, not %s" % type(code))
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be dict, not %s" % type(metadata))

    content = dict(code=code, silent=bool(silent), user_expressions={})

    msg = self.session.send(socket, "execute_request", content=content, ident=ident,
                        metadata=metadata)

    # book-keeping for the new outstanding request
    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine
        if isinstance(ident, list):
            ident = ident[-1]
        if ident in self._engines.values():
            # save per-engine as well, in case of engine death
            self._outstanding_dict[ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1307
1307
1308 #--------------------------------------------------------------------------
1308 #--------------------------------------------------------------------------
1309 # construct a View object
1309 # construct a View object
1310 #--------------------------------------------------------------------------
1310 #--------------------------------------------------------------------------
1311
1311
def load_balanced_view(self, targets=None):
    """construct a LoadBalancedView object.

    If no arguments are specified, create a LoadBalancedView
    using all engines.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The subset of engines across which to load-balance

    Returns
    -------

    LoadBalancedView
        A view that submits work via the task scheduler socket.
    """
    # 'all' and None both mean "no target restriction"
    if targets == 'all':
        targets = None
    if targets is not None:
        # resolve to concrete engine ids
        targets = self._build_targets(targets)[1]
    return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1329
1329
def direct_view(self, targets='all'):
    """construct a DirectView object.

    If no targets are specified, create a DirectView using all engines.

    rc.direct_view('all') is distinguished from rc[:] in that 'all' will
    evaluate the target engines at each execution, whereas rc[:] will connect to
    all *current* engines, and that list will not change.

    That is, 'all' will always use all engines, whereas rc[:] will not use
    engines added after the DirectView is constructed.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The engines to use for the View
    """
    want_single = isinstance(targets, int)
    # 'all' stays symbolic so it is re-evaluated at each execution
    if targets != 'all':
        targets = self._build_targets(targets)[1]
        if want_single:
            # an int target yields a single-engine view
            targets = targets[0]
    return DirectView(client=self, socket=self._mux_socket, targets=targets)
1355
1355
1356 #--------------------------------------------------------------------------
1356 #--------------------------------------------------------------------------
1357 # Query methods
1357 # Query methods
1358 #--------------------------------------------------------------------------
1358 #--------------------------------------------------------------------------
1359
1359
1360 @spin_first
1360 @spin_first
1361 def get_result(self, indices_or_msg_ids=None, block=None):
1361 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1363
1363
1364 If the client already has the results, no request to the Hub will be made.
1364 If the client already has the results, no request to the Hub will be made.
1365
1365
1366 This is a convenient way to construct AsyncResult objects, which are wrappers
1366 This is a convenient way to construct AsyncResult objects, which are wrappers
1367 that include metadata about execution, and allow for awaiting results that
1367 that include metadata about execution, and allow for awaiting results that
1368 were not submitted by this Client.
1368 were not submitted by this Client.
1369
1369
1370 It can also be a convenient way to retrieve the metadata associated with
1370 It can also be a convenient way to retrieve the metadata associated with
1371 blocking execution, since it always retrieves
1371 blocking execution, since it always retrieves
1372
1372
1373 Examples
1373 Examples
1374 --------
1374 --------
1375 ::
1375 ::
1376
1376
1377 In [10]: r = client.apply()
1377 In [10]: r = client.apply()
1378
1378
1379 Parameters
1379 Parameters
1380 ----------
1380 ----------
1381
1381
1382 indices_or_msg_ids : integer history index, str msg_id, or list of either
1382 indices_or_msg_ids : integer history index, str msg_id, or list of either
1383 The indices or msg_ids of indices to be retrieved
1383 The indices or msg_ids of indices to be retrieved
1384
1384
1385 block : bool
1385 block : bool
1386 Whether to wait for the result to be done
1386 Whether to wait for the result to be done
1387 owner : bool [default: True]
1388 Whether this AsyncResult should own the result.
1389 If so, calling `ar.get()` will remove data from the
1390 client's result and metadata cache.
1391 There should only be one owner of any given msg_id.
1387
1392
1388 Returns
1393 Returns
1389 -------
1394 -------
1390
1395
1391 AsyncResult
1396 AsyncResult
1392 A single AsyncResult object will always be returned.
1397 A single AsyncResult object will always be returned.
1393
1398
1394 AsyncHubResult
1399 AsyncHubResult
1395 A subclass of AsyncResult that retrieves results from the Hub
1400 A subclass of AsyncResult that retrieves results from the Hub
1396
1401
1397 """
1402 """
1398 block = self.block if block is None else block
1403 block = self.block if block is None else block
1399 if indices_or_msg_ids is None:
1404 if indices_or_msg_ids is None:
1400 indices_or_msg_ids = -1
1405 indices_or_msg_ids = -1
1401
1406
1402 single_result = False
1407 single_result = False
1403 if not isinstance(indices_or_msg_ids, (list,tuple)):
1408 if not isinstance(indices_or_msg_ids, (list,tuple)):
1404 indices_or_msg_ids = [indices_or_msg_ids]
1409 indices_or_msg_ids = [indices_or_msg_ids]
1405 single_result = True
1410 single_result = True
1406
1411
1407 theids = []
1412 theids = []
1408 for id in indices_or_msg_ids:
1413 for id in indices_or_msg_ids:
1409 if isinstance(id, int):
1414 if isinstance(id, int):
1410 id = self.history[id]
1415 id = self.history[id]
1411 if not isinstance(id, string_types):
1416 if not isinstance(id, string_types):
1412 raise TypeError("indices must be str or int, not %r"%id)
1417 raise TypeError("indices must be str or int, not %r"%id)
1413 theids.append(id)
1418 theids.append(id)
1414
1419
1415 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1420 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1416 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1421 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1417
1422
1418 # given single msg_id initially, get_result shot get the result itself,
1423 # given single msg_id initially, get_result shot get the result itself,
1419 # not a length-one list
1424 # not a length-one list
1420 if single_result:
1425 if single_result:
1421 theids = theids[0]
1426 theids = theids[0]
1422
1427
1423 if remote_ids:
1428 if remote_ids:
1424 ar = AsyncHubResult(self, msg_ids=theids)
1429 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1425 else:
1430 else:
1426 ar = AsyncResult(self, msg_ids=theids)
1431 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1427
1432
1428 if block:
1433 if block:
1429 ar.wait()
1434 ar.wait()
1430
1435
1431 return ar
1436 return ar
1432
1437
1433 @spin_first
1438 @spin_first
1434 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1439 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1435 """Resubmit one or more tasks.
1440 """Resubmit one or more tasks.
1436
1441
1437 in-flight tasks may not be resubmitted.
1442 in-flight tasks may not be resubmitted.
1438
1443
1439 Parameters
1444 Parameters
1440 ----------
1445 ----------
1441
1446
1442 indices_or_msg_ids : integer history index, str msg_id, or list of either
1447 indices_or_msg_ids : integer history index, str msg_id, or list of either
1443 The indices or msg_ids of indices to be retrieved
1448 The indices or msg_ids of indices to be retrieved
1444
1449
1445 block : bool
1450 block : bool
1446 Whether to wait for the result to be done
1451 Whether to wait for the result to be done
1447
1452
1448 Returns
1453 Returns
1449 -------
1454 -------
1450
1455
1451 AsyncHubResult
1456 AsyncHubResult
1452 A subclass of AsyncResult that retrieves results from the Hub
1457 A subclass of AsyncResult that retrieves results from the Hub
1453
1458
1454 """
1459 """
1455 block = self.block if block is None else block
1460 block = self.block if block is None else block
1456 if indices_or_msg_ids is None:
1461 if indices_or_msg_ids is None:
1457 indices_or_msg_ids = -1
1462 indices_or_msg_ids = -1
1458
1463
1459 if not isinstance(indices_or_msg_ids, (list,tuple)):
1464 if not isinstance(indices_or_msg_ids, (list,tuple)):
1460 indices_or_msg_ids = [indices_or_msg_ids]
1465 indices_or_msg_ids = [indices_or_msg_ids]
1461
1466
1462 theids = []
1467 theids = []
1463 for id in indices_or_msg_ids:
1468 for id in indices_or_msg_ids:
1464 if isinstance(id, int):
1469 if isinstance(id, int):
1465 id = self.history[id]
1470 id = self.history[id]
1466 if not isinstance(id, string_types):
1471 if not isinstance(id, string_types):
1467 raise TypeError("indices must be str or int, not %r"%id)
1472 raise TypeError("indices must be str or int, not %r"%id)
1468 theids.append(id)
1473 theids.append(id)
1469
1474
1470 content = dict(msg_ids = theids)
1475 content = dict(msg_ids = theids)
1471
1476
1472 self.session.send(self._query_socket, 'resubmit_request', content)
1477 self.session.send(self._query_socket, 'resubmit_request', content)
1473
1478
1474 zmq.select([self._query_socket], [], [])
1479 zmq.select([self._query_socket], [], [])
1475 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1480 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1476 if self.debug:
1481 if self.debug:
1477 pprint(msg)
1482 pprint(msg)
1478 content = msg['content']
1483 content = msg['content']
1479 if content['status'] != 'ok':
1484 if content['status'] != 'ok':
1480 raise self._unwrap_exception(content)
1485 raise self._unwrap_exception(content)
1481 mapping = content['resubmitted']
1486 mapping = content['resubmitted']
1482 new_ids = [ mapping[msg_id] for msg_id in theids ]
1487 new_ids = [ mapping[msg_id] for msg_id in theids ]
1483
1488
1484 ar = AsyncHubResult(self, msg_ids=new_ids)
1489 ar = AsyncHubResult(self, msg_ids=new_ids)
1485
1490
1486 if block:
1491 if block:
1487 ar.wait()
1492 ar.wait()
1488
1493
1489 return ar
1494 return ar
1490
1495
1491 @spin_first
1496 @spin_first
1492 def result_status(self, msg_ids, status_only=True):
1497 def result_status(self, msg_ids, status_only=True):
1493 """Check on the status of the result(s) of the apply request with `msg_ids`.
1498 """Check on the status of the result(s) of the apply request with `msg_ids`.
1494
1499
1495 If status_only is False, then the actual results will be retrieved, else
1500 If status_only is False, then the actual results will be retrieved, else
1496 only the status of the results will be checked.
1501 only the status of the results will be checked.
1497
1502
1498 Parameters
1503 Parameters
1499 ----------
1504 ----------
1500
1505
1501 msg_ids : list of msg_ids
1506 msg_ids : list of msg_ids
1502 if int:
1507 if int:
1503 Passed as index to self.history for convenience.
1508 Passed as index to self.history for convenience.
1504 status_only : bool (default: True)
1509 status_only : bool (default: True)
1505 if False:
1510 if False:
1506 Retrieve the actual results of completed tasks.
1511 Retrieve the actual results of completed tasks.
1507
1512
1508 Returns
1513 Returns
1509 -------
1514 -------
1510
1515
1511 results : dict
1516 results : dict
1512 There will always be the keys 'pending' and 'completed', which will
1517 There will always be the keys 'pending' and 'completed', which will
1513 be lists of msg_ids that are incomplete or complete. If `status_only`
1518 be lists of msg_ids that are incomplete or complete. If `status_only`
1514 is False, then completed results will be keyed by their `msg_id`.
1519 is False, then completed results will be keyed by their `msg_id`.
1515 """
1520 """
1516 if not isinstance(msg_ids, (list,tuple)):
1521 if not isinstance(msg_ids, (list,tuple)):
1517 msg_ids = [msg_ids]
1522 msg_ids = [msg_ids]
1518
1523
1519 theids = []
1524 theids = []
1520 for msg_id in msg_ids:
1525 for msg_id in msg_ids:
1521 if isinstance(msg_id, int):
1526 if isinstance(msg_id, int):
1522 msg_id = self.history[msg_id]
1527 msg_id = self.history[msg_id]
1523 if not isinstance(msg_id, string_types):
1528 if not isinstance(msg_id, string_types):
1524 raise TypeError("msg_ids must be str, not %r"%msg_id)
1529 raise TypeError("msg_ids must be str, not %r"%msg_id)
1525 theids.append(msg_id)
1530 theids.append(msg_id)
1526
1531
1527 completed = []
1532 completed = []
1528 local_results = {}
1533 local_results = {}
1529
1534
1530 # comment this block out to temporarily disable local shortcut:
1535 # comment this block out to temporarily disable local shortcut:
1531 for msg_id in theids:
1536 for msg_id in theids:
1532 if msg_id in self.results:
1537 if msg_id in self.results:
1533 completed.append(msg_id)
1538 completed.append(msg_id)
1534 local_results[msg_id] = self.results[msg_id]
1539 local_results[msg_id] = self.results[msg_id]
1535 theids.remove(msg_id)
1540 theids.remove(msg_id)
1536
1541
1537 if theids: # some not locally cached
1542 if theids: # some not locally cached
1538 content = dict(msg_ids=theids, status_only=status_only)
1543 content = dict(msg_ids=theids, status_only=status_only)
1539 msg = self.session.send(self._query_socket, "result_request", content=content)
1544 msg = self.session.send(self._query_socket, "result_request", content=content)
1540 zmq.select([self._query_socket], [], [])
1545 zmq.select([self._query_socket], [], [])
1541 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1546 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1542 if self.debug:
1547 if self.debug:
1543 pprint(msg)
1548 pprint(msg)
1544 content = msg['content']
1549 content = msg['content']
1545 if content['status'] != 'ok':
1550 if content['status'] != 'ok':
1546 raise self._unwrap_exception(content)
1551 raise self._unwrap_exception(content)
1547 buffers = msg['buffers']
1552 buffers = msg['buffers']
1548 else:
1553 else:
1549 content = dict(completed=[],pending=[])
1554 content = dict(completed=[],pending=[])
1550
1555
1551 content['completed'].extend(completed)
1556 content['completed'].extend(completed)
1552
1557
1553 if status_only:
1558 if status_only:
1554 return content
1559 return content
1555
1560
1556 failures = []
1561 failures = []
1557 # load cached results into result:
1562 # load cached results into result:
1558 content.update(local_results)
1563 content.update(local_results)
1559
1564
1560 # update cache with results:
1565 # update cache with results:
1561 for msg_id in sorted(theids):
1566 for msg_id in sorted(theids):
1562 if msg_id in content['completed']:
1567 if msg_id in content['completed']:
1563 rec = content[msg_id]
1568 rec = content[msg_id]
1564 parent = extract_dates(rec['header'])
1569 parent = extract_dates(rec['header'])
1565 header = extract_dates(rec['result_header'])
1570 header = extract_dates(rec['result_header'])
1566 rcontent = rec['result_content']
1571 rcontent = rec['result_content']
1567 iodict = rec['io']
1572 iodict = rec['io']
1568 if isinstance(rcontent, str):
1573 if isinstance(rcontent, str):
1569 rcontent = self.session.unpack(rcontent)
1574 rcontent = self.session.unpack(rcontent)
1570
1575
1571 md = self.metadata[msg_id]
1576 md = self.metadata[msg_id]
1572 md_msg = dict(
1577 md_msg = dict(
1573 content=rcontent,
1578 content=rcontent,
1574 parent_header=parent,
1579 parent_header=parent,
1575 header=header,
1580 header=header,
1576 metadata=rec['result_metadata'],
1581 metadata=rec['result_metadata'],
1577 )
1582 )
1578 md.update(self._extract_metadata(md_msg))
1583 md.update(self._extract_metadata(md_msg))
1579 if rec.get('received'):
1584 if rec.get('received'):
1580 md['received'] = parse_date(rec['received'])
1585 md['received'] = parse_date(rec['received'])
1581 md.update(iodict)
1586 md.update(iodict)
1582
1587
1583 if rcontent['status'] == 'ok':
1588 if rcontent['status'] == 'ok':
1584 if header['msg_type'] == 'apply_reply':
1589 if header['msg_type'] == 'apply_reply':
1585 res,buffers = serialize.unserialize_object(buffers)
1590 res,buffers = serialize.unserialize_object(buffers)
1586 elif header['msg_type'] == 'execute_reply':
1591 elif header['msg_type'] == 'execute_reply':
1587 res = ExecuteReply(msg_id, rcontent, md)
1592 res = ExecuteReply(msg_id, rcontent, md)
1588 else:
1593 else:
1589 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1594 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1590 else:
1595 else:
1591 res = self._unwrap_exception(rcontent)
1596 res = self._unwrap_exception(rcontent)
1592 failures.append(res)
1597 failures.append(res)
1593
1598
1594 self.results[msg_id] = res
1599 self.results[msg_id] = res
1595 content[msg_id] = res
1600 content[msg_id] = res
1596
1601
1597 if len(theids) == 1 and failures:
1602 if len(theids) == 1 and failures:
1598 raise failures[0]
1603 raise failures[0]
1599
1604
1600 error.collect_exceptions(failures, "result_status")
1605 error.collect_exceptions(failures, "result_status")
1601 return content
1606 return content
1602
1607
1603 @spin_first
1608 @spin_first
1604 def queue_status(self, targets='all', verbose=False):
1609 def queue_status(self, targets='all', verbose=False):
1605 """Fetch the status of engine queues.
1610 """Fetch the status of engine queues.
1606
1611
1607 Parameters
1612 Parameters
1608 ----------
1613 ----------
1609
1614
1610 targets : int/str/list of ints/strs
1615 targets : int/str/list of ints/strs
1611 the engines whose states are to be queried.
1616 the engines whose states are to be queried.
1612 default : all
1617 default : all
1613 verbose : bool
1618 verbose : bool
1614 Whether to return lengths only, or lists of ids for each element
1619 Whether to return lengths only, or lists of ids for each element
1615 """
1620 """
1616 if targets == 'all':
1621 if targets == 'all':
1617 # allow 'all' to be evaluated on the engine
1622 # allow 'all' to be evaluated on the engine
1618 engine_ids = None
1623 engine_ids = None
1619 else:
1624 else:
1620 engine_ids = self._build_targets(targets)[1]
1625 engine_ids = self._build_targets(targets)[1]
1621 content = dict(targets=engine_ids, verbose=verbose)
1626 content = dict(targets=engine_ids, verbose=verbose)
1622 self.session.send(self._query_socket, "queue_request", content=content)
1627 self.session.send(self._query_socket, "queue_request", content=content)
1623 idents,msg = self.session.recv(self._query_socket, 0)
1628 idents,msg = self.session.recv(self._query_socket, 0)
1624 if self.debug:
1629 if self.debug:
1625 pprint(msg)
1630 pprint(msg)
1626 content = msg['content']
1631 content = msg['content']
1627 status = content.pop('status')
1632 status = content.pop('status')
1628 if status != 'ok':
1633 if status != 'ok':
1629 raise self._unwrap_exception(content)
1634 raise self._unwrap_exception(content)
1630 content = rekey(content)
1635 content = rekey(content)
1631 if isinstance(targets, int):
1636 if isinstance(targets, int):
1632 return content[targets]
1637 return content[targets]
1633 else:
1638 else:
1634 return content
1639 return content
1635
1640
1636 def _build_msgids_from_target(self, targets=None):
1641 def _build_msgids_from_target(self, targets=None):
1637 """Build a list of msg_ids from the list of engine targets"""
1642 """Build a list of msg_ids from the list of engine targets"""
1638 if not targets: # needed as _build_targets otherwise uses all engines
1643 if not targets: # needed as _build_targets otherwise uses all engines
1639 return []
1644 return []
1640 target_ids = self._build_targets(targets)[0]
1645 target_ids = self._build_targets(targets)[0]
1641 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1646 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1642
1647
1643 def _build_msgids_from_jobs(self, jobs=None):
1648 def _build_msgids_from_jobs(self, jobs=None):
1644 """Build a list of msg_ids from "jobs" """
1649 """Build a list of msg_ids from "jobs" """
1645 if not jobs:
1650 if not jobs:
1646 return []
1651 return []
1647 msg_ids = []
1652 msg_ids = []
1648 if isinstance(jobs, string_types + (AsyncResult,)):
1653 if isinstance(jobs, string_types + (AsyncResult,)):
1649 jobs = [jobs]
1654 jobs = [jobs]
1650 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1655 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1651 if bad_ids:
1656 if bad_ids:
1652 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1657 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1653 for j in jobs:
1658 for j in jobs:
1654 if isinstance(j, AsyncResult):
1659 if isinstance(j, AsyncResult):
1655 msg_ids.extend(j.msg_ids)
1660 msg_ids.extend(j.msg_ids)
1656 else:
1661 else:
1657 msg_ids.append(j)
1662 msg_ids.append(j)
1658 return msg_ids
1663 return msg_ids
1659
1664
1660 def purge_local_results(self, jobs=[], targets=[]):
1665 def purge_local_results(self, jobs=[], targets=[]):
1661 """Clears the client caches of results and their metadata.
1666 """Clears the client caches of results and their metadata.
1662
1667
1663 Individual results can be purged by msg_id, or the entire
1668 Individual results can be purged by msg_id, or the entire
1664 history of specific targets can be purged.
1669 history of specific targets can be purged.
1665
1670
1666 Use `purge_local_results('all')` to scrub everything from the Clients's
1671 Use `purge_local_results('all')` to scrub everything from the Clients's
1667 results and metadata caches.
1672 results and metadata caches.
1668
1673
1669 After this call all `AsyncResults` are invalid and should be discarded.
1674 After this call all `AsyncResults` are invalid and should be discarded.
1670
1675
1671 If you must "reget" the results, you can still do so by using
1676 If you must "reget" the results, you can still do so by using
1672 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1673 redownload the results from the hub if they are still available
1678 redownload the results from the hub if they are still available
1674 (i.e `client.purge_hub_results(...)` has not been called.
1679 (i.e `client.purge_hub_results(...)` has not been called.
1675
1680
1676 Parameters
1681 Parameters
1677 ----------
1682 ----------
1678
1683
1679 jobs : str or list of str or AsyncResult objects
1684 jobs : str or list of str or AsyncResult objects
1680 the msg_ids whose results should be purged.
1685 the msg_ids whose results should be purged.
1681 targets : int/list of ints
1686 targets : int/list of ints
1682 The engines, by integer ID, whose entire result histories are to be purged.
1687 The engines, by integer ID, whose entire result histories are to be purged.
1683
1688
1684 Raises
1689 Raises
1685 ------
1690 ------
1686
1691
1687 RuntimeError : if any of the tasks to be purged are still outstanding.
1692 RuntimeError : if any of the tasks to be purged are still outstanding.
1688
1693
1689 """
1694 """
1690 if not targets and not jobs:
1695 if not targets and not jobs:
1691 raise ValueError("Must specify at least one of `targets` and `jobs`")
1696 raise ValueError("Must specify at least one of `targets` and `jobs`")
1692
1697
1693 if jobs == 'all':
1698 if jobs == 'all':
1694 if self.outstanding:
1699 if self.outstanding:
1695 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1700 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1696 self.results.clear()
1701 self.results.clear()
1697 self.metadata.clear()
1702 self.metadata.clear()
1698 else:
1703 else:
1699 msg_ids = set()
1704 msg_ids = set()
1700 msg_ids.update(self._build_msgids_from_target(targets))
1705 msg_ids.update(self._build_msgids_from_target(targets))
1701 msg_ids.update(self._build_msgids_from_jobs(jobs))
1706 msg_ids.update(self._build_msgids_from_jobs(jobs))
1702 still_outstanding = self.outstanding.intersection(msg_ids)
1707 still_outstanding = self.outstanding.intersection(msg_ids)
1703 if still_outstanding:
1708 if still_outstanding:
1704 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1709 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1705 for mid in msg_ids:
1710 for mid in msg_ids:
1706 self.results.pop(mid)
1711 self.results.pop(mid, None)
1707 self.metadata.pop(mid)
1712 self.metadata.pop(mid, None)
1708
1713
1709
1714
1710 @spin_first
1715 @spin_first
1711 def purge_hub_results(self, jobs=[], targets=[]):
1716 def purge_hub_results(self, jobs=[], targets=[]):
1712 """Tell the Hub to forget results.
1717 """Tell the Hub to forget results.
1713
1718
1714 Individual results can be purged by msg_id, or the entire
1719 Individual results can be purged by msg_id, or the entire
1715 history of specific targets can be purged.
1720 history of specific targets can be purged.
1716
1721
1717 Use `purge_results('all')` to scrub everything from the Hub's db.
1722 Use `purge_results('all')` to scrub everything from the Hub's db.
1718
1723
1719 Parameters
1724 Parameters
1720 ----------
1725 ----------
1721
1726
1722 jobs : str or list of str or AsyncResult objects
1727 jobs : str or list of str or AsyncResult objects
1723 the msg_ids whose results should be forgotten.
1728 the msg_ids whose results should be forgotten.
1724 targets : int/str/list of ints/strs
1729 targets : int/str/list of ints/strs
1725 The targets, by int_id, whose entire history is to be purged.
1730 The targets, by int_id, whose entire history is to be purged.
1726
1731
1727 default : None
1732 default : None
1728 """
1733 """
1729 if not targets and not jobs:
1734 if not targets and not jobs:
1730 raise ValueError("Must specify at least one of `targets` and `jobs`")
1735 raise ValueError("Must specify at least one of `targets` and `jobs`")
1731 if targets:
1736 if targets:
1732 targets = self._build_targets(targets)[1]
1737 targets = self._build_targets(targets)[1]
1733
1738
1734 # construct msg_ids from jobs
1739 # construct msg_ids from jobs
1735 if jobs == 'all':
1740 if jobs == 'all':
1736 msg_ids = jobs
1741 msg_ids = jobs
1737 else:
1742 else:
1738 msg_ids = self._build_msgids_from_jobs(jobs)
1743 msg_ids = self._build_msgids_from_jobs(jobs)
1739
1744
1740 content = dict(engine_ids=targets, msg_ids=msg_ids)
1745 content = dict(engine_ids=targets, msg_ids=msg_ids)
1741 self.session.send(self._query_socket, "purge_request", content=content)
1746 self.session.send(self._query_socket, "purge_request", content=content)
1742 idents, msg = self.session.recv(self._query_socket, 0)
1747 idents, msg = self.session.recv(self._query_socket, 0)
1743 if self.debug:
1748 if self.debug:
1744 pprint(msg)
1749 pprint(msg)
1745 content = msg['content']
1750 content = msg['content']
1746 if content['status'] != 'ok':
1751 if content['status'] != 'ok':
1747 raise self._unwrap_exception(content)
1752 raise self._unwrap_exception(content)
1748
1753
    def purge_results(self, jobs=[], targets=[]):
        """Clears the cached results from both the hub and the local client

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_results('all')` to scrub every cached result from both the Hub's and
        the Client's db.

        Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
        the same arguments.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.

            default : None
        """
        # purge locally first, then tell the Hub to forget the same set
        self.purge_local_results(jobs=jobs, targets=targets)
        self.purge_hub_results(jobs=jobs, targets=targets)
1773
1778
    def purge_everything(self):
        """Clears all content from previous Tasks from both the hub and the local client

        In addition to calling `purge_results("all")` it also deletes the history and
        other bookkeeping lists.
        """
        self.purge_results("all")
        # drop local bookkeeping as well: submission history and the session's
        # record of previously-seen message digests
        self.history = []
        self.session.digest_history.clear()
1783
1788
    @spin_first
    def hub_history(self):
        """Get the Hub's history

        Just like the Client, the Hub has a history, which is a list of msg_ids.
        This will contain the history of all clients, and, depending on configuration,
        may contain history across multiple cluster sessions.

        Any msg_id returned here is a valid argument to `get_result`.

        Returns
        -------

        msg_ids : list of strs
            list of all msg_ids, ordered by task submission time.
        """

        self.session.send(self._query_socket, "history_request", content={})
        # blocking receive of the Hub's reply
        idents, msg = self.session.recv(self._query_socket, 0)

        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        else:
            return content['history']
1811
1816
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.
        """
        if isinstance(keys, string_types):
            # allow a single key to be passed without wrapping it in a list
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # per-record buffer counts; None means no buffers were requested/returned
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i, rec in enumerate(records):
            # unpack datetime objects
            for hkey in ('header', 'result_header'):
                if hkey in rec:
                    rec[hkey] = extract_dates(rec[hkey])
            for dtkey in ('submitted', 'started', 'completed', 'received'):
                if dtkey in rec:
                    rec[dtkey] = parse_date(rec[dtkey])
            # relink buffers: the flat `buffers` list is consumed in record order;
            # buffer_lens[i] / result_buffer_lens[i] give each record's share,
            # so the slicing below must stay sequential.
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen], buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]

        return records
1862
1867
# public API of this module
__all__ = [ 'Client' ]
@@ -1,1131 +1,1125 b''
1 """Views of remote engines.
1 """Views of remote engines."""
2
2
3 Authors:
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4
5
5 * Min RK
6 """
7 from __future__ import print_function
6 from __future__ import print_function
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
18
7
19 import imp
8 import imp
20 import sys
9 import sys
21 import warnings
10 import warnings
22 from contextlib import contextmanager
11 from contextlib import contextmanager
23 from types import ModuleType
12 from types import ModuleType
24
13
25 import zmq
14 import zmq
26
15
27 from IPython.testing.skipdoctest import skip_doctest
16 from IPython.testing.skipdoctest import skip_doctest
28 from IPython.utils import pickleutil
17 from IPython.utils import pickleutil
29 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
30 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
31 )
20 )
32 from IPython.external.decorator import decorator
21 from IPython.external.decorator import decorator
33
22
34 from IPython.parallel import util
23 from IPython.parallel import util
35 from IPython.parallel.controller.dependency import Dependency, dependent
24 from IPython.parallel.controller.dependency import Dependency, dependent
36 from IPython.utils.py3compat import string_types, iteritems, PY3
25 from IPython.utils.py3compat import string_types, iteritems, PY3
37
26
38 from . import map as Map
27 from . import map as Map
39 from .asyncresult import AsyncResult, AsyncMapResult
28 from .asyncresult import AsyncResult, AsyncMapResult
40 from .remotefunction import ParallelFunction, parallel, remote, getname
29 from .remotefunction import ParallelFunction, parallel, remote, getname
41
30
42 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
43 # Decorators
32 # Decorators
44 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
45
34
46 @decorator
35 @decorator
47 def save_ids(f, self, *args, **kwargs):
36 def save_ids(f, self, *args, **kwargs):
48 """Keep our history and outstanding attributes up to date after a method call."""
37 """Keep our history and outstanding attributes up to date after a method call."""
49 n_previous = len(self.client.history)
38 n_previous = len(self.client.history)
50 try:
39 try:
51 ret = f(self, *args, **kwargs)
40 ret = f(self, *args, **kwargs)
52 finally:
41 finally:
53 nmsgs = len(self.client.history) - n_previous
42 nmsgs = len(self.client.history) - n_previous
54 msg_ids = self.client.history[-nmsgs:]
43 msg_ids = self.client.history[-nmsgs:]
55 self.history.extend(msg_ids)
44 self.history.extend(msg_ids)
56 self.outstanding.update(msg_ids)
45 self.outstanding.update(msg_ids)
57 return ret
46 return ret
58
47
59 @decorator
48 @decorator
60 def sync_results(f, self, *args, **kwargs):
49 def sync_results(f, self, *args, **kwargs):
61 """sync relevant results from self.client to our results attribute."""
50 """sync relevant results from self.client to our results attribute."""
62 if self._in_sync_results:
51 if self._in_sync_results:
63 return f(self, *args, **kwargs)
52 return f(self, *args, **kwargs)
64 self._in_sync_results = True
53 self._in_sync_results = True
65 try:
54 try:
66 ret = f(self, *args, **kwargs)
55 ret = f(self, *args, **kwargs)
67 finally:
56 finally:
68 self._in_sync_results = False
57 self._in_sync_results = False
69 self._sync_results()
58 self._sync_results()
70 return ret
59 return ret
71
60
72 @decorator
61 @decorator
73 def spin_after(f, self, *args, **kwargs):
62 def spin_after(f, self, *args, **kwargs):
74 """call spin after the method."""
63 """call spin after the method."""
75 ret = f(self, *args, **kwargs)
64 ret = f(self, *args, **kwargs)
76 self.spin()
65 self.spin()
77 return ret
66 return ret
78
67
79 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
80 # Classes
69 # Classes
81 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
82
71
83 @skip_doctest
72 @skip_doctest
84 class View(HasTraits):
73 class View(HasTraits):
85 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
86
75
87 Don't use this class, use subclasses.
76 Don't use this class, use subclasses.
88
77
89 Methods
78 Methods
90 -------
79 -------
91
80
92 spin
81 spin
93 flushes incoming results and registration state changes
82 flushes incoming results and registration state changes
94 control methods spin, and requesting `ids` also ensures up to date
83 control methods spin, and requesting `ids` also ensures up to date
95
84
96 wait
85 wait
97 wait on one or more msg_ids
86 wait on one or more msg_ids
98
87
99 execution methods
88 execution methods
100 apply
89 apply
101 legacy: execute, run
90 legacy: execute, run
102
91
103 data movement
92 data movement
104 push, pull, scatter, gather
93 push, pull, scatter, gather
105
94
106 query methods
95 query methods
107 get_result, queue_status, purge_results, result_status
96 get_result, queue_status, purge_results, result_status
108
97
109 control methods
98 control methods
110 abort, shutdown
99 abort, shutdown
111
100
112 """
101 """
113 # flags
102 # flags
114 block=Bool(False)
103 block=Bool(False)
115 track=Bool(True)
104 track=Bool(True)
116 targets = Any()
105 targets = Any()
117
106
118 history=List()
107 history=List()
119 outstanding = Set()
108 outstanding = Set()
120 results = Dict()
109 results = Dict()
121 client = Instance('IPython.parallel.Client')
110 client = Instance('IPython.parallel.Client')
122
111
123 _socket = Instance('zmq.Socket')
112 _socket = Instance('zmq.Socket')
124 _flag_names = List(['targets', 'block', 'track'])
113 _flag_names = List(['targets', 'block', 'track'])
125 _in_sync_results = Bool(False)
114 _in_sync_results = Bool(False)
126 _targets = Any()
115 _targets = Any()
127 _idents = Any()
116 _idents = Any()
128
117
129 def __init__(self, client=None, socket=None, **flags):
118 def __init__(self, client=None, socket=None, **flags):
130 super(View, self).__init__(client=client, _socket=socket)
119 super(View, self).__init__(client=client, _socket=socket)
131 self.results = client.results
120 self.results = client.results
132 self.block = client.block
121 self.block = client.block
133
122
134 self.set_flags(**flags)
123 self.set_flags(**flags)
135
124
136 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
137
126
138 def __repr__(self):
127 def __repr__(self):
139 strtargets = str(self.targets)
128 strtargets = str(self.targets)
140 if len(strtargets) > 16:
129 if len(strtargets) > 16:
141 strtargets = strtargets[:12]+'...]'
130 strtargets = strtargets[:12]+'...]'
142 return "<%s %s>"%(self.__class__.__name__, strtargets)
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
143
132
144 def __len__(self):
133 def __len__(self):
145 if isinstance(self.targets, list):
134 if isinstance(self.targets, list):
146 return len(self.targets)
135 return len(self.targets)
147 elif isinstance(self.targets, int):
136 elif isinstance(self.targets, int):
148 return 1
137 return 1
149 else:
138 else:
150 return len(self.client)
139 return len(self.client)
151
140
152 def set_flags(self, **kwargs):
141 def set_flags(self, **kwargs):
153 """set my attribute flags by keyword.
142 """set my attribute flags by keyword.
154
143
155 Views determine behavior with a few attributes (`block`, `track`, etc.).
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
156 These attributes can be set all at once by name with this method.
145 These attributes can be set all at once by name with this method.
157
146
158 Parameters
147 Parameters
159 ----------
148 ----------
160
149
161 block : bool
150 block : bool
162 whether to wait for results
151 whether to wait for results
163 track : bool
152 track : bool
164 whether to create a MessageTracker to allow the user to
153 whether to create a MessageTracker to allow the user to
165 safely edit after arrays and buffers during non-copying
154 safely edit after arrays and buffers during non-copying
166 sends.
155 sends.
167 """
156 """
168 for name, value in iteritems(kwargs):
157 for name, value in iteritems(kwargs):
169 if name not in self._flag_names:
158 if name not in self._flag_names:
170 raise KeyError("Invalid name: %r"%name)
159 raise KeyError("Invalid name: %r"%name)
171 else:
160 else:
172 setattr(self, name, value)
161 setattr(self, name, value)
173
162
174 @contextmanager
163 @contextmanager
175 def temp_flags(self, **kwargs):
164 def temp_flags(self, **kwargs):
176 """temporarily set flags, for use in `with` statements.
165 """temporarily set flags, for use in `with` statements.
177
166
178 See set_flags for permanent setting of flags
167 See set_flags for permanent setting of flags
179
168
180 Examples
169 Examples
181 --------
170 --------
182
171
183 >>> view.track=False
172 >>> view.track=False
184 ...
173 ...
185 >>> with view.temp_flags(track=True):
174 >>> with view.temp_flags(track=True):
186 ... ar = view.apply(dostuff, my_big_array)
175 ... ar = view.apply(dostuff, my_big_array)
187 ... ar.tracker.wait() # wait for send to finish
176 ... ar.tracker.wait() # wait for send to finish
188 >>> view.track
177 >>> view.track
189 False
178 False
190
179
191 """
180 """
192 # preflight: save flags, and set temporaries
181 # preflight: save flags, and set temporaries
193 saved_flags = {}
182 saved_flags = {}
194 for f in self._flag_names:
183 for f in self._flag_names:
195 saved_flags[f] = getattr(self, f)
184 saved_flags[f] = getattr(self, f)
196 self.set_flags(**kwargs)
185 self.set_flags(**kwargs)
197 # yield to the with-statement block
186 # yield to the with-statement block
198 try:
187 try:
199 yield
188 yield
200 finally:
189 finally:
201 # postflight: restore saved flags
190 # postflight: restore saved flags
202 self.set_flags(**saved_flags)
191 self.set_flags(**saved_flags)
203
192
204
193
205 #----------------------------------------------------------------
194 #----------------------------------------------------------------
206 # apply
195 # apply
207 #----------------------------------------------------------------
196 #----------------------------------------------------------------
208
197
209 def _sync_results(self):
198 def _sync_results(self):
210 """to be called by @sync_results decorator
199 """to be called by @sync_results decorator
211
200
212 after submitting any tasks.
201 after submitting any tasks.
213 """
202 """
214 delta = self.outstanding.difference(self.client.outstanding)
203 delta = self.outstanding.difference(self.client.outstanding)
215 completed = self.outstanding.intersection(delta)
204 completed = self.outstanding.intersection(delta)
216 self.outstanding = self.outstanding.difference(completed)
205 self.outstanding = self.outstanding.difference(completed)
217
206
218 @sync_results
207 @sync_results
219 @save_ids
208 @save_ids
220 def _really_apply(self, f, args, kwargs, block=None, **options):
209 def _really_apply(self, f, args, kwargs, block=None, **options):
221 """wrapper for client.send_apply_request"""
210 """wrapper for client.send_apply_request"""
222 raise NotImplementedError("Implement in subclasses")
211 raise NotImplementedError("Implement in subclasses")
223
212
224 def apply(self, f, *args, **kwargs):
213 def apply(self, f, *args, **kwargs):
225 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
226
215
227 This method sets all apply flags via this View's attributes.
216 This method sets all apply flags via this View's attributes.
228
217
229 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
230 instance if ``self.block`` is False, otherwise the return value of
219 instance if ``self.block`` is False, otherwise the return value of
231 ``f(*args, **kwargs)``.
220 ``f(*args, **kwargs)``.
232 """
221 """
233 return self._really_apply(f, args, kwargs)
222 return self._really_apply(f, args, kwargs)
234
223
235 def apply_async(self, f, *args, **kwargs):
224 def apply_async(self, f, *args, **kwargs):
236 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
237
226
238 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
239 """
228 """
240 return self._really_apply(f, args, kwargs, block=False)
229 return self._really_apply(f, args, kwargs, block=False)
241
230
242 @spin_after
231 @spin_after
243 def apply_sync(self, f, *args, **kwargs):
232 def apply_sync(self, f, *args, **kwargs):
244 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
245 returning the result.
234 returning the result.
246 """
235 """
247 return self._really_apply(f, args, kwargs, block=True)
236 return self._really_apply(f, args, kwargs, block=True)
248
237
249 #----------------------------------------------------------------
238 #----------------------------------------------------------------
250 # wrappers for client and control methods
239 # wrappers for client and control methods
251 #----------------------------------------------------------------
240 #----------------------------------------------------------------
252 @sync_results
241 @sync_results
253 def spin(self):
242 def spin(self):
254 """spin the client, and sync"""
243 """spin the client, and sync"""
255 self.client.spin()
244 self.client.spin()
256
245
257 @sync_results
246 @sync_results
258 def wait(self, jobs=None, timeout=-1):
247 def wait(self, jobs=None, timeout=-1):
259 """waits on one or more `jobs`, for up to `timeout` seconds.
248 """waits on one or more `jobs`, for up to `timeout` seconds.
260
249
261 Parameters
250 Parameters
262 ----------
251 ----------
263
252
264 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
265 ints are indices to self.history
254 ints are indices to self.history
266 strs are msg_ids
255 strs are msg_ids
267 default: wait on all outstanding messages
256 default: wait on all outstanding messages
268 timeout : float
257 timeout : float
269 a time in seconds, after which to give up.
258 a time in seconds, after which to give up.
270 default is -1, which means no timeout
259 default is -1, which means no timeout
271
260
272 Returns
261 Returns
273 -------
262 -------
274
263
275 True : when all msg_ids are done
264 True : when all msg_ids are done
276 False : timeout reached, some msg_ids still outstanding
265 False : timeout reached, some msg_ids still outstanding
277 """
266 """
278 if jobs is None:
267 if jobs is None:
279 jobs = self.history
268 jobs = self.history
280 return self.client.wait(jobs, timeout)
269 return self.client.wait(jobs, timeout)
281
270
282 def abort(self, jobs=None, targets=None, block=None):
271 def abort(self, jobs=None, targets=None, block=None):
283 """Abort jobs on my engines.
272 """Abort jobs on my engines.
284
273
285 Parameters
274 Parameters
286 ----------
275 ----------
287
276
288 jobs : None, str, list of strs, optional
277 jobs : None, str, list of strs, optional
289 if None: abort all jobs.
278 if None: abort all jobs.
290 else: abort specific msg_id(s).
279 else: abort specific msg_id(s).
291 """
280 """
292 block = block if block is not None else self.block
281 block = block if block is not None else self.block
293 targets = targets if targets is not None else self.targets
282 targets = targets if targets is not None else self.targets
294 jobs = jobs if jobs is not None else list(self.outstanding)
283 jobs = jobs if jobs is not None else list(self.outstanding)
295
284
296 return self.client.abort(jobs=jobs, targets=targets, block=block)
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
297
286
298 def queue_status(self, targets=None, verbose=False):
287 def queue_status(self, targets=None, verbose=False):
299 """Fetch the Queue status of my engines"""
288 """Fetch the Queue status of my engines"""
300 targets = targets if targets is not None else self.targets
289 targets = targets if targets is not None else self.targets
301 return self.client.queue_status(targets=targets, verbose=verbose)
290 return self.client.queue_status(targets=targets, verbose=verbose)
302
291
303 def purge_results(self, jobs=[], targets=[]):
292 def purge_results(self, jobs=[], targets=[]):
304 """Instruct the controller to forget specific results."""
293 """Instruct the controller to forget specific results."""
305 if targets is None or targets == 'all':
294 if targets is None or targets == 'all':
306 targets = self.targets
295 targets = self.targets
307 return self.client.purge_results(jobs=jobs, targets=targets)
296 return self.client.purge_results(jobs=jobs, targets=targets)
308
297
309 def shutdown(self, targets=None, restart=False, hub=False, block=None):
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
310 """Terminates one or more engine processes, optionally including the hub.
299 """Terminates one or more engine processes, optionally including the hub.
311 """
300 """
312 block = self.block if block is None else block
301 block = self.block if block is None else block
313 if targets is None or targets == 'all':
302 if targets is None or targets == 'all':
314 targets = self.targets
303 targets = self.targets
315 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
316
305
317 @spin_after
306 @spin_after
318 def get_result(self, indices_or_msg_ids=None):
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
319 """return one or more results, specified by history index or msg_id.
308 """return one or more results, specified by history index or msg_id.
320
309
321 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
322 """
311 """
323
312
324 if indices_or_msg_ids is None:
313 if indices_or_msg_ids is None:
325 indices_or_msg_ids = -1
314 indices_or_msg_ids = -1
326 if isinstance(indices_or_msg_ids, int):
315 if isinstance(indices_or_msg_ids, int):
327 indices_or_msg_ids = self.history[indices_or_msg_ids]
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
328 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
329 indices_or_msg_ids = list(indices_or_msg_ids)
318 indices_or_msg_ids = list(indices_or_msg_ids)
330 for i,index in enumerate(indices_or_msg_ids):
319 for i,index in enumerate(indices_or_msg_ids):
331 if isinstance(index, int):
320 if isinstance(index, int):
332 indices_or_msg_ids[i] = self.history[index]
321 indices_or_msg_ids[i] = self.history[index]
333 return self.client.get_result(indices_or_msg_ids)
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
334
323
335 #-------------------------------------------------------------------
324 #-------------------------------------------------------------------
336 # Map
325 # Map
337 #-------------------------------------------------------------------
326 #-------------------------------------------------------------------
338
327
339 @sync_results
328 @sync_results
340 def map(self, f, *sequences, **kwargs):
329 def map(self, f, *sequences, **kwargs):
341 """override in subclasses"""
330 """override in subclasses"""
342 raise NotImplementedError
331 raise NotImplementedError
343
332
344 def map_async(self, f, *sequences, **kwargs):
333 def map_async(self, f, *sequences, **kwargs):
345 """Parallel version of builtin :func:`python:map`, using this view's engines.
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
346
335
347 This is equivalent to ``map(...block=False)``.
336 This is equivalent to ``map(...block=False)``.
348
337
349 See `self.map` for details.
338 See `self.map` for details.
350 """
339 """
351 if 'block' in kwargs:
340 if 'block' in kwargs:
352 raise TypeError("map_async doesn't take a `block` keyword argument.")
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
353 kwargs['block'] = False
342 kwargs['block'] = False
354 return self.map(f,*sequences,**kwargs)
343 return self.map(f,*sequences,**kwargs)
355
344
356 def map_sync(self, f, *sequences, **kwargs):
345 def map_sync(self, f, *sequences, **kwargs):
357 """Parallel version of builtin :func:`python:map`, using this view's engines.
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
358
347
359 This is equivalent to ``map(...block=True)``.
348 This is equivalent to ``map(...block=True)``.
360
349
361 See `self.map` for details.
350 See `self.map` for details.
362 """
351 """
363 if 'block' in kwargs:
352 if 'block' in kwargs:
364 raise TypeError("map_sync doesn't take a `block` keyword argument.")
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
365 kwargs['block'] = True
354 kwargs['block'] = True
366 return self.map(f,*sequences,**kwargs)
355 return self.map(f,*sequences,**kwargs)
367
356
368 def imap(self, f, *sequences, **kwargs):
357 def imap(self, f, *sequences, **kwargs):
369 """Parallel version of :func:`itertools.imap`.
358 """Parallel version of :func:`itertools.imap`.
370
359
371 See `self.map` for details.
360 See `self.map` for details.
372
361
373 """
362 """
374
363
375 return iter(self.map_async(f,*sequences, **kwargs))
364 return iter(self.map_async(f,*sequences, **kwargs))
376
365
377 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
378 # Decorators
367 # Decorators
379 #-------------------------------------------------------------------
368 #-------------------------------------------------------------------
380
369
381 def remote(self, block=None, **flags):
370 def remote(self, block=None, **flags):
382 """Decorator for making a RemoteFunction"""
371 """Decorator for making a RemoteFunction"""
383 block = self.block if block is None else block
372 block = self.block if block is None else block
384 return remote(self, block=block, **flags)
373 return remote(self, block=block, **flags)
385
374
386 def parallel(self, dist='b', block=None, **flags):
375 def parallel(self, dist='b', block=None, **flags):
387 """Decorator for making a ParallelFunction"""
376 """Decorator for making a ParallelFunction"""
388 block = self.block if block is None else block
377 block = self.block if block is None else block
389 return parallel(self, dist=dist, block=block, **flags)
378 return parallel(self, dist=dist, block=block, **flags)
390
379
391 @skip_doctest
380 @skip_doctest
392 class DirectView(View):
381 class DirectView(View):
393 """Direct Multiplexer View of one or more engines.
382 """Direct Multiplexer View of one or more engines.
394
383
395 These are created via indexed access to a client:
384 These are created via indexed access to a client:
396
385
397 >>> dv_1 = client[1]
386 >>> dv_1 = client[1]
398 >>> dv_all = client[:]
387 >>> dv_all = client[:]
399 >>> dv_even = client[::2]
388 >>> dv_even = client[::2]
400 >>> dv_some = client[1:3]
389 >>> dv_some = client[1:3]
401
390
402 This object provides dictionary access to engine namespaces:
391 This object provides dictionary access to engine namespaces:
403
392
404 # push a=5:
393 # push a=5:
405 >>> dv['a'] = 5
394 >>> dv['a'] = 5
406 # pull 'foo':
395 # pull 'foo':
407 >>> dv['foo']
396 >>> dv['foo']
408
397
409 """
398 """
410
399
411 def __init__(self, client=None, socket=None, targets=None):
400 def __init__(self, client=None, socket=None, targets=None):
412 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
413
402
414 @property
403 @property
415 def importer(self):
404 def importer(self):
416 """sync_imports(local=True) as a property.
405 """sync_imports(local=True) as a property.
417
406
418 See sync_imports for details.
407 See sync_imports for details.
419
408
420 """
409 """
421 return self.sync_imports(True)
410 return self.sync_imports(True)
422
411
423 @contextmanager
412 @contextmanager
424 def sync_imports(self, local=True, quiet=False):
413 def sync_imports(self, local=True, quiet=False):
425 """Context Manager for performing simultaneous local and remote imports.
414 """Context Manager for performing simultaneous local and remote imports.
426
415
427 'import x as y' will *not* work. The 'as y' part will simply be ignored.
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
428
417
429 If `local=True`, then the package will also be imported locally.
418 If `local=True`, then the package will also be imported locally.
430
419
431 If `quiet=True`, no output will be produced when attempting remote
420 If `quiet=True`, no output will be produced when attempting remote
432 imports.
421 imports.
433
422
434 Note that remote-only (`local=False`) imports have not been implemented.
423 Note that remote-only (`local=False`) imports have not been implemented.
435
424
436 >>> with view.sync_imports():
425 >>> with view.sync_imports():
437 ... from numpy import recarray
426 ... from numpy import recarray
438 importing recarray from numpy on engine(s)
427 importing recarray from numpy on engine(s)
439
428
440 """
429 """
441 from IPython.utils.py3compat import builtin_mod
430 from IPython.utils.py3compat import builtin_mod
442 local_import = builtin_mod.__import__
431 local_import = builtin_mod.__import__
443 modules = set()
432 modules = set()
444 results = []
433 results = []
445 @util.interactive
434 @util.interactive
446 def remote_import(name, fromlist, level):
435 def remote_import(name, fromlist, level):
447 """the function to be passed to apply, that actually performs the import
436 """the function to be passed to apply, that actually performs the import
448 on the engine, and loads up the user namespace.
437 on the engine, and loads up the user namespace.
449 """
438 """
450 import sys
439 import sys
451 user_ns = globals()
440 user_ns = globals()
452 mod = __import__(name, fromlist=fromlist, level=level)
441 mod = __import__(name, fromlist=fromlist, level=level)
453 if fromlist:
442 if fromlist:
454 for key in fromlist:
443 for key in fromlist:
455 user_ns[key] = getattr(mod, key)
444 user_ns[key] = getattr(mod, key)
456 else:
445 else:
457 user_ns[name] = sys.modules[name]
446 user_ns[name] = sys.modules[name]
458
447
459 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
460 """the drop-in replacement for __import__, that optionally imports
449 """the drop-in replacement for __import__, that optionally imports
461 locally as well.
450 locally as well.
462 """
451 """
463 # don't override nested imports
452 # don't override nested imports
464 save_import = builtin_mod.__import__
453 save_import = builtin_mod.__import__
465 builtin_mod.__import__ = local_import
454 builtin_mod.__import__ = local_import
466
455
467 if imp.lock_held():
456 if imp.lock_held():
468 # this is a side-effect import, don't do it remotely, or even
457 # this is a side-effect import, don't do it remotely, or even
469 # ignore the local effects
458 # ignore the local effects
470 return local_import(name, globals, locals, fromlist, level)
459 return local_import(name, globals, locals, fromlist, level)
471
460
472 imp.acquire_lock()
461 imp.acquire_lock()
473 if local:
462 if local:
474 mod = local_import(name, globals, locals, fromlist, level)
463 mod = local_import(name, globals, locals, fromlist, level)
475 else:
464 else:
476 raise NotImplementedError("remote-only imports not yet implemented")
465 raise NotImplementedError("remote-only imports not yet implemented")
477 imp.release_lock()
466 imp.release_lock()
478
467
479 key = name+':'+','.join(fromlist or [])
468 key = name+':'+','.join(fromlist or [])
480 if level <= 0 and key not in modules:
469 if level <= 0 and key not in modules:
481 modules.add(key)
470 modules.add(key)
482 if not quiet:
471 if not quiet:
483 if fromlist:
472 if fromlist:
484 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
485 else:
474 else:
486 print("importing %s on engine(s)"%name)
475 print("importing %s on engine(s)"%name)
487 results.append(self.apply_async(remote_import, name, fromlist, level))
476 results.append(self.apply_async(remote_import, name, fromlist, level))
488 # restore override
477 # restore override
489 builtin_mod.__import__ = save_import
478 builtin_mod.__import__ = save_import
490
479
491 return mod
480 return mod
492
481
493 # override __import__
482 # override __import__
494 builtin_mod.__import__ = view_import
483 builtin_mod.__import__ = view_import
495 try:
484 try:
496 # enter the block
485 # enter the block
497 yield
486 yield
498 except ImportError:
487 except ImportError:
499 if local:
488 if local:
500 raise
489 raise
501 else:
490 else:
502 # ignore import errors if not doing local imports
491 # ignore import errors if not doing local imports
503 pass
492 pass
504 finally:
493 finally:
505 # always restore __import__
494 # always restore __import__
506 builtin_mod.__import__ = local_import
495 builtin_mod.__import__ = local_import
507
496
508 for r in results:
497 for r in results:
509 # raise possible remote ImportErrors here
498 # raise possible remote ImportErrors here
510 r.get()
499 r.get()
511
500
512 def use_dill(self):
501 def use_dill(self):
513 """Expand serialization support with dill
502 """Expand serialization support with dill
514
503
515 adds support for closures, etc.
504 adds support for closures, etc.
516
505
517 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
518 """
507 """
519 pickleutil.use_dill()
508 pickleutil.use_dill()
520 return self.apply(pickleutil.use_dill)
509 return self.apply(pickleutil.use_dill)
521
510
522 def use_cloudpickle(self):
511 def use_cloudpickle(self):
523 """Expand serialization support with cloudpickle.
512 """Expand serialization support with cloudpickle.
524 """
513 """
525 pickleutil.use_cloudpickle()
514 pickleutil.use_cloudpickle()
526 return self.apply(pickleutil.use_cloudpickle)
515 return self.apply(pickleutil.use_cloudpickle)
527
516
528
517
529 @sync_results
518 @sync_results
530 @save_ids
519 @save_ids
531 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
532 """calls f(*args, **kwargs) on remote engines, returning the result.
521 """calls f(*args, **kwargs) on remote engines, returning the result.
533
522
534 This method sets all of `apply`'s flags via this View's attributes.
523 This method sets all of `apply`'s flags via this View's attributes.
535
524
536 Parameters
525 Parameters
537 ----------
526 ----------
538
527
539 f : callable
528 f : callable
540
529
541 args : list [default: empty]
530 args : list [default: empty]
542
531
543 kwargs : dict [default: empty]
532 kwargs : dict [default: empty]
544
533
545 targets : target list [default: self.targets]
534 targets : target list [default: self.targets]
546 where to run
535 where to run
547 block : bool [default: self.block]
536 block : bool [default: self.block]
548 whether to block
537 whether to block
549 track : bool [default: self.track]
538 track : bool [default: self.track]
550 whether to ask zmq to track the message, for safe non-copying sends
539 whether to ask zmq to track the message, for safe non-copying sends
551
540
552 Returns
541 Returns
553 -------
542 -------
554
543
555 if self.block is False:
544 if self.block is False:
556 returns AsyncResult
545 returns AsyncResult
557 else:
546 else:
558 returns actual result of f(*args, **kwargs) on the engine(s)
547 returns actual result of f(*args, **kwargs) on the engine(s)
559 This will be a list of self.targets is also a list (even length 1), or
548 This will be a list of self.targets is also a list (even length 1), or
560 the single result if self.targets is an integer engine id
549 the single result if self.targets is an integer engine id
561 """
550 """
562 args = [] if args is None else args
551 args = [] if args is None else args
563 kwargs = {} if kwargs is None else kwargs
552 kwargs = {} if kwargs is None else kwargs
564 block = self.block if block is None else block
553 block = self.block if block is None else block
565 track = self.track if track is None else track
554 track = self.track if track is None else track
566 targets = self.targets if targets is None else targets
555 targets = self.targets if targets is None else targets
567
556
568 _idents, _targets = self.client._build_targets(targets)
557 _idents, _targets = self.client._build_targets(targets)
569 msg_ids = []
558 msg_ids = []
570 trackers = []
559 trackers = []
571 for ident in _idents:
560 for ident in _idents:
572 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
573 ident=ident)
562 ident=ident)
574 if track:
563 if track:
575 trackers.append(msg['tracker'])
564 trackers.append(msg['tracker'])
576 msg_ids.append(msg['header']['msg_id'])
565 msg_ids.append(msg['header']['msg_id'])
577 if isinstance(targets, int):
566 if isinstance(targets, int):
578 msg_ids = msg_ids[0]
567 msg_ids = msg_ids[0]
579 tracker = None if track is False else zmq.MessageTracker(*trackers)
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
580 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
571 )
581 if block:
572 if block:
582 try:
573 try:
583 return ar.get()
574 return ar.get()
584 except KeyboardInterrupt:
575 except KeyboardInterrupt:
585 pass
576 pass
586 return ar
577 return ar
587
578
588
579
589 @sync_results
580 @sync_results
590 def map(self, f, *sequences, **kwargs):
581 def map(self, f, *sequences, **kwargs):
591 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
592
583
593 Parallel version of builtin `map`, using this View's `targets`.
584 Parallel version of builtin `map`, using this View's `targets`.
594
585
595 There will be one task per target, so work will be chunked
586 There will be one task per target, so work will be chunked
596 if the sequences are longer than `targets`.
587 if the sequences are longer than `targets`.
597
588
598 Results can be iterated as they are ready, but will become available in chunks.
589 Results can be iterated as they are ready, but will become available in chunks.
599
590
600 Parameters
591 Parameters
601 ----------
592 ----------
602
593
603 f : callable
594 f : callable
604 function to be mapped
595 function to be mapped
605 *sequences: one or more sequences of matching length
596 *sequences: one or more sequences of matching length
606 the sequences to be distributed and passed to `f`
597 the sequences to be distributed and passed to `f`
607 block : bool
598 block : bool
608 whether to wait for the result or not [default self.block]
599 whether to wait for the result or not [default self.block]
609
600
610 Returns
601 Returns
611 -------
602 -------
612
603
613
604
614 If block=False
605 If block=False
615 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
606 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
616 An object like AsyncResult, but which reassembles the sequence of results
607 An object like AsyncResult, but which reassembles the sequence of results
617 into a single list. AsyncMapResults can be iterated through before all
608 into a single list. AsyncMapResults can be iterated through before all
618 results are complete.
609 results are complete.
619 else
610 else
620 A list, the result of ``map(f,*sequences)``
611 A list, the result of ``map(f,*sequences)``
621 """
612 """
622
613
623 block = kwargs.pop('block', self.block)
614 block = kwargs.pop('block', self.block)
624 for k in kwargs.keys():
615 for k in kwargs.keys():
625 if k not in ['block', 'track']:
616 if k not in ['block', 'track']:
626 raise TypeError("invalid keyword arg, %r"%k)
617 raise TypeError("invalid keyword arg, %r"%k)
627
618
628 assert len(sequences) > 0, "must have some sequences to map onto!"
619 assert len(sequences) > 0, "must have some sequences to map onto!"
629 pf = ParallelFunction(self, f, block=block, **kwargs)
620 pf = ParallelFunction(self, f, block=block, **kwargs)
630 return pf.map(*sequences)
621 return pf.map(*sequences)
631
622
632 @sync_results
623 @sync_results
633 @save_ids
624 @save_ids
634 def execute(self, code, silent=True, targets=None, block=None):
625 def execute(self, code, silent=True, targets=None, block=None):
635 """Executes `code` on `targets` in blocking or nonblocking manner.
626 """Executes `code` on `targets` in blocking or nonblocking manner.
636
627
637 ``execute`` is always `bound` (affects engine namespace)
628 ``execute`` is always `bound` (affects engine namespace)
638
629
639 Parameters
630 Parameters
640 ----------
631 ----------
641
632
642 code : str
633 code : str
643 the code string to be executed
634 the code string to be executed
644 block : bool
635 block : bool
645 whether or not to wait until done to return
636 whether or not to wait until done to return
646 default: self.block
637 default: self.block
647 """
638 """
648 block = self.block if block is None else block
639 block = self.block if block is None else block
649 targets = self.targets if targets is None else targets
640 targets = self.targets if targets is None else targets
650
641
651 _idents, _targets = self.client._build_targets(targets)
642 _idents, _targets = self.client._build_targets(targets)
652 msg_ids = []
643 msg_ids = []
653 trackers = []
644 trackers = []
654 for ident in _idents:
645 for ident in _idents:
655 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
656 msg_ids.append(msg['header']['msg_id'])
647 msg_ids.append(msg['header']['msg_id'])
657 if isinstance(targets, int):
648 if isinstance(targets, int):
658 msg_ids = msg_ids[0]
649 msg_ids = msg_ids[0]
659 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
660 if block:
651 if block:
661 try:
652 try:
662 ar.get()
653 ar.get()
663 except KeyboardInterrupt:
654 except KeyboardInterrupt:
664 pass
655 pass
665 return ar
656 return ar
666
657
667 def run(self, filename, targets=None, block=None):
658 def run(self, filename, targets=None, block=None):
668 """Execute contents of `filename` on my engine(s).
659 """Execute contents of `filename` on my engine(s).
669
660
670 This simply reads the contents of the file and calls `execute`.
661 This simply reads the contents of the file and calls `execute`.
671
662
672 Parameters
663 Parameters
673 ----------
664 ----------
674
665
675 filename : str
666 filename : str
676 The path to the file
667 The path to the file
677 targets : int/str/list of ints/strs
668 targets : int/str/list of ints/strs
678 the engines on which to execute
669 the engines on which to execute
679 default : all
670 default : all
680 block : bool
671 block : bool
681 whether or not to wait until done
672 whether or not to wait until done
682 default: self.block
673 default: self.block
683
674
684 """
675 """
685 with open(filename, 'r') as f:
676 with open(filename, 'r') as f:
686 # add newline in case of trailing indented whitespace
677 # add newline in case of trailing indented whitespace
687 # which will cause SyntaxError
678 # which will cause SyntaxError
688 code = f.read()+'\n'
679 code = f.read()+'\n'
689 return self.execute(code, block=block, targets=targets)
680 return self.execute(code, block=block, targets=targets)
690
681
691 def update(self, ns):
682 def update(self, ns):
692 """update remote namespace with dict `ns`
683 """update remote namespace with dict `ns`
693
684
694 See `push` for details.
685 See `push` for details.
695 """
686 """
696 return self.push(ns, block=self.block, track=self.track)
687 return self.push(ns, block=self.block, track=self.track)
697
688
698 def push(self, ns, targets=None, block=None, track=None):
689 def push(self, ns, targets=None, block=None, track=None):
699 """update remote namespace with dict `ns`
690 """update remote namespace with dict `ns`
700
691
701 Parameters
692 Parameters
702 ----------
693 ----------
703
694
704 ns : dict
695 ns : dict
705 dict of keys with which to update engine namespace(s)
696 dict of keys with which to update engine namespace(s)
706 block : bool [default : self.block]
697 block : bool [default : self.block]
707 whether to wait to be notified of engine receipt
698 whether to wait to be notified of engine receipt
708
699
709 """
700 """
710
701
711 block = block if block is not None else self.block
702 block = block if block is not None else self.block
712 track = track if track is not None else self.track
703 track = track if track is not None else self.track
713 targets = targets if targets is not None else self.targets
704 targets = targets if targets is not None else self.targets
714 # applier = self.apply_sync if block else self.apply_async
705 # applier = self.apply_sync if block else self.apply_async
715 if not isinstance(ns, dict):
706 if not isinstance(ns, dict):
716 raise TypeError("Must be a dict, not %s"%type(ns))
707 raise TypeError("Must be a dict, not %s"%type(ns))
717 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
718
709
719 def get(self, key_s):
710 def get(self, key_s):
720 """get object(s) by `key_s` from remote namespace
711 """get object(s) by `key_s` from remote namespace
721
712
722 see `pull` for details.
713 see `pull` for details.
723 """
714 """
724 # block = block if block is not None else self.block
715 # block = block if block is not None else self.block
725 return self.pull(key_s, block=True)
716 return self.pull(key_s, block=True)
726
717
727 def pull(self, names, targets=None, block=None):
718 def pull(self, names, targets=None, block=None):
728 """get object(s) by `name` from remote namespace
719 """get object(s) by `name` from remote namespace
729
720
730 will return one object if it is a key.
721 will return one object if it is a key.
731 can also take a list of keys, in which case it will return a list of objects.
722 can also take a list of keys, in which case it will return a list of objects.
732 """
723 """
733 block = block if block is not None else self.block
724 block = block if block is not None else self.block
734 targets = targets if targets is not None else self.targets
725 targets = targets if targets is not None else self.targets
735 applier = self.apply_sync if block else self.apply_async
726 applier = self.apply_sync if block else self.apply_async
736 if isinstance(names, string_types):
727 if isinstance(names, string_types):
737 pass
728 pass
738 elif isinstance(names, (list,tuple,set)):
729 elif isinstance(names, (list,tuple,set)):
739 for key in names:
730 for key in names:
740 if not isinstance(key, string_types):
731 if not isinstance(key, string_types):
741 raise TypeError("keys must be str, not type %r"%type(key))
732 raise TypeError("keys must be str, not type %r"%type(key))
742 else:
733 else:
743 raise TypeError("names must be strs, not %r"%names)
734 raise TypeError("names must be strs, not %r"%names)
744 return self._really_apply(util._pull, (names,), block=block, targets=targets)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
745
736
746 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
747 """
738 """
748 Partition a Python sequence and send the partitions to a set of engines.
739 Partition a Python sequence and send the partitions to a set of engines.
749 """
740 """
750 block = block if block is not None else self.block
741 block = block if block is not None else self.block
751 track = track if track is not None else self.track
742 track = track if track is not None else self.track
752 targets = targets if targets is not None else self.targets
743 targets = targets if targets is not None else self.targets
753
744
754 # construct integer ID list:
745 # construct integer ID list:
755 targets = self.client._build_targets(targets)[1]
746 targets = self.client._build_targets(targets)[1]
756
747
757 mapObject = Map.dists[dist]()
748 mapObject = Map.dists[dist]()
758 nparts = len(targets)
749 nparts = len(targets)
759 msg_ids = []
750 msg_ids = []
760 trackers = []
751 trackers = []
761 for index, engineid in enumerate(targets):
752 for index, engineid in enumerate(targets):
762 partition = mapObject.getPartition(seq, index, nparts)
753 partition = mapObject.getPartition(seq, index, nparts)
763 if flatten and len(partition) == 1:
754 if flatten and len(partition) == 1:
764 ns = {key: partition[0]}
755 ns = {key: partition[0]}
765 else:
756 else:
766 ns = {key: partition}
757 ns = {key: partition}
767 r = self.push(ns, block=False, track=track, targets=engineid)
758 r = self.push(ns, block=False, track=track, targets=engineid)
768 msg_ids.extend(r.msg_ids)
759 msg_ids.extend(r.msg_ids)
769 if track:
760 if track:
770 trackers.append(r._tracker)
761 trackers.append(r._tracker)
771
762
772 if track:
763 if track:
773 tracker = zmq.MessageTracker(*trackers)
764 tracker = zmq.MessageTracker(*trackers)
774 else:
765 else:
775 tracker = None
766 tracker = None
776
767
777 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
770 )
778 if block:
771 if block:
779 r.wait()
772 r.wait()
780 else:
773 else:
781 return r
774 return r
782
775
783 @sync_results
776 @sync_results
784 @save_ids
777 @save_ids
785 def gather(self, key, dist='b', targets=None, block=None):
778 def gather(self, key, dist='b', targets=None, block=None):
786 """
779 """
787 Gather a partitioned sequence on a set of engines as a single local seq.
780 Gather a partitioned sequence on a set of engines as a single local seq.
788 """
781 """
789 block = block if block is not None else self.block
782 block = block if block is not None else self.block
790 targets = targets if targets is not None else self.targets
783 targets = targets if targets is not None else self.targets
791 mapObject = Map.dists[dist]()
784 mapObject = Map.dists[dist]()
792 msg_ids = []
785 msg_ids = []
793
786
794 # construct integer ID list:
787 # construct integer ID list:
795 targets = self.client._build_targets(targets)[1]
788 targets = self.client._build_targets(targets)[1]
796
789
797 for index, engineid in enumerate(targets):
790 for index, engineid in enumerate(targets):
798 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
799
792
800 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
801
794
802 if block:
795 if block:
803 try:
796 try:
804 return r.get()
797 return r.get()
805 except KeyboardInterrupt:
798 except KeyboardInterrupt:
806 pass
799 pass
807 return r
800 return r
808
801
809 def __getitem__(self, key):
802 def __getitem__(self, key):
810 return self.get(key)
803 return self.get(key)
811
804
812 def __setitem__(self,key, value):
805 def __setitem__(self,key, value):
813 self.update({key:value})
806 self.update({key:value})
814
807
815 def clear(self, targets=None, block=None):
808 def clear(self, targets=None, block=None):
816 """Clear the remote namespaces on my engines."""
809 """Clear the remote namespaces on my engines."""
817 block = block if block is not None else self.block
810 block = block if block is not None else self.block
818 targets = targets if targets is not None else self.targets
811 targets = targets if targets is not None else self.targets
819 return self.client.clear(targets=targets, block=block)
812 return self.client.clear(targets=targets, block=block)
820
813
821 #----------------------------------------
814 #----------------------------------------
822 # activate for %px, %autopx, etc. magics
815 # activate for %px, %autopx, etc. magics
823 #----------------------------------------
816 #----------------------------------------
824
817
825 def activate(self, suffix=''):
818 def activate(self, suffix=''):
826 """Activate IPython magics associated with this View
819 """Activate IPython magics associated with this View
827
820
828 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
829
822
830 Parameters
823 Parameters
831 ----------
824 ----------
832
825
833 suffix: str [default: '']
826 suffix: str [default: '']
834 The suffix, if any, for the magics. This allows you to have
827 The suffix, if any, for the magics. This allows you to have
835 multiple views associated with parallel magics at the same time.
828 multiple views associated with parallel magics at the same time.
836
829
837 e.g. ``rc[::2].activate(suffix='_even')`` will give you
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
838 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
839 on the even engines.
832 on the even engines.
840 """
833 """
841
834
842 from IPython.parallel.client.magics import ParallelMagics
835 from IPython.parallel.client.magics import ParallelMagics
843
836
844 try:
837 try:
845 # This is injected into __builtins__.
838 # This is injected into __builtins__.
846 ip = get_ipython()
839 ip = get_ipython()
847 except NameError:
840 except NameError:
848 print("The IPython parallel magics (%px, etc.) only work within IPython.")
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
849 return
842 return
850
843
851 M = ParallelMagics(ip, self, suffix)
844 M = ParallelMagics(ip, self, suffix)
852 ip.magics_manager.register(M)
845 ip.magics_manager.register(M)
853
846
854
847
855 @skip_doctest
848 @skip_doctest
856 class LoadBalancedView(View):
849 class LoadBalancedView(View):
857 """An load-balancing View that only executes via the Task scheduler.
850 """An load-balancing View that only executes via the Task scheduler.
858
851
859 Load-balanced views can be created with the client's `view` method:
852 Load-balanced views can be created with the client's `view` method:
860
853
861 >>> v = client.load_balanced_view()
854 >>> v = client.load_balanced_view()
862
855
863 or targets can be specified, to restrict the potential destinations:
856 or targets can be specified, to restrict the potential destinations:
864
857
865 >>> v = client.client.load_balanced_view([1,3])
858 >>> v = client.client.load_balanced_view([1,3])
866
859
867 which would restrict loadbalancing to between engines 1 and 3.
860 which would restrict loadbalancing to between engines 1 and 3.
868
861
869 """
862 """
870
863
871 follow=Any()
864 follow=Any()
872 after=Any()
865 after=Any()
873 timeout=CFloat()
866 timeout=CFloat()
874 retries = Integer(0)
867 retries = Integer(0)
875
868
876 _task_scheme = Any()
869 _task_scheme = Any()
877 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
878
871
879 def __init__(self, client=None, socket=None, **flags):
872 def __init__(self, client=None, socket=None, **flags):
880 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
881 self._task_scheme=client._task_scheme
874 self._task_scheme=client._task_scheme
882
875
883 def _validate_dependency(self, dep):
876 def _validate_dependency(self, dep):
884 """validate a dependency.
877 """validate a dependency.
885
878
886 For use in `set_flags`.
879 For use in `set_flags`.
887 """
880 """
888 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
889 return True
882 return True
890 elif isinstance(dep, (list,set, tuple)):
883 elif isinstance(dep, (list,set, tuple)):
891 for d in dep:
884 for d in dep:
892 if not isinstance(d, string_types + (AsyncResult,)):
885 if not isinstance(d, string_types + (AsyncResult,)):
893 return False
886 return False
894 elif isinstance(dep, dict):
887 elif isinstance(dep, dict):
895 if set(dep.keys()) != set(Dependency().as_dict().keys()):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
896 return False
889 return False
897 if not isinstance(dep['msg_ids'], list):
890 if not isinstance(dep['msg_ids'], list):
898 return False
891 return False
899 for d in dep['msg_ids']:
892 for d in dep['msg_ids']:
900 if not isinstance(d, string_types):
893 if not isinstance(d, string_types):
901 return False
894 return False
902 else:
895 else:
903 return False
896 return False
904
897
905 return True
898 return True
906
899
907 def _render_dependency(self, dep):
900 def _render_dependency(self, dep):
908 """helper for building jsonable dependencies from various input forms."""
901 """helper for building jsonable dependencies from various input forms."""
909 if isinstance(dep, Dependency):
902 if isinstance(dep, Dependency):
910 return dep.as_dict()
903 return dep.as_dict()
911 elif isinstance(dep, AsyncResult):
904 elif isinstance(dep, AsyncResult):
912 return dep.msg_ids
905 return dep.msg_ids
913 elif dep is None:
906 elif dep is None:
914 return []
907 return []
915 else:
908 else:
916 # pass to Dependency constructor
909 # pass to Dependency constructor
917 return list(Dependency(dep))
910 return list(Dependency(dep))
918
911
919 def set_flags(self, **kwargs):
912 def set_flags(self, **kwargs):
920 """set my attribute flags by keyword.
913 """set my attribute flags by keyword.
921
914
922 A View is a wrapper for the Client's apply method, but with attributes
915 A View is a wrapper for the Client's apply method, but with attributes
923 that specify keyword arguments, those attributes can be set by keyword
916 that specify keyword arguments, those attributes can be set by keyword
924 argument with this method.
917 argument with this method.
925
918
926 Parameters
919 Parameters
927 ----------
920 ----------
928
921
929 block : bool
922 block : bool
930 whether to wait for results
923 whether to wait for results
931 track : bool
924 track : bool
932 whether to create a MessageTracker to allow the user to
925 whether to create a MessageTracker to allow the user to
933 safely edit after arrays and buffers during non-copying
926 safely edit after arrays and buffers during non-copying
934 sends.
927 sends.
935
928
936 after : Dependency or collection of msg_ids
929 after : Dependency or collection of msg_ids
937 Only for load-balanced execution (targets=None)
930 Only for load-balanced execution (targets=None)
938 Specify a list of msg_ids as a time-based dependency.
931 Specify a list of msg_ids as a time-based dependency.
939 This job will only be run *after* the dependencies
932 This job will only be run *after* the dependencies
940 have been met.
933 have been met.
941
934
942 follow : Dependency or collection of msg_ids
935 follow : Dependency or collection of msg_ids
943 Only for load-balanced execution (targets=None)
936 Only for load-balanced execution (targets=None)
944 Specify a list of msg_ids as a location-based dependency.
937 Specify a list of msg_ids as a location-based dependency.
945 This job will only be run on an engine where this dependency
938 This job will only be run on an engine where this dependency
946 is met.
939 is met.
947
940
948 timeout : float/int or None
941 timeout : float/int or None
949 Only for load-balanced execution (targets=None)
942 Only for load-balanced execution (targets=None)
950 Specify an amount of time (in seconds) for the scheduler to
943 Specify an amount of time (in seconds) for the scheduler to
951 wait for dependencies to be met before failing with a
944 wait for dependencies to be met before failing with a
952 DependencyTimeout.
945 DependencyTimeout.
953
946
954 retries : int
947 retries : int
955 Number of times a task will be retried on failure.
948 Number of times a task will be retried on failure.
956 """
949 """
957
950
958 super(LoadBalancedView, self).set_flags(**kwargs)
951 super(LoadBalancedView, self).set_flags(**kwargs)
959 for name in ('follow', 'after'):
952 for name in ('follow', 'after'):
960 if name in kwargs:
953 if name in kwargs:
961 value = kwargs[name]
954 value = kwargs[name]
962 if self._validate_dependency(value):
955 if self._validate_dependency(value):
963 setattr(self, name, value)
956 setattr(self, name, value)
964 else:
957 else:
965 raise ValueError("Invalid dependency: %r"%value)
958 raise ValueError("Invalid dependency: %r"%value)
966 if 'timeout' in kwargs:
959 if 'timeout' in kwargs:
967 t = kwargs['timeout']
960 t = kwargs['timeout']
968 if not isinstance(t, (int, float, type(None))):
961 if not isinstance(t, (int, float, type(None))):
969 if (not PY3) and (not isinstance(t, long)):
962 if (not PY3) and (not isinstance(t, long)):
970 raise TypeError("Invalid type for timeout: %r"%type(t))
963 raise TypeError("Invalid type for timeout: %r"%type(t))
971 if t is not None:
964 if t is not None:
972 if t < 0:
965 if t < 0:
973 raise ValueError("Invalid timeout: %s"%t)
966 raise ValueError("Invalid timeout: %s"%t)
974 self.timeout = t
967 self.timeout = t
975
968
976 @sync_results
969 @sync_results
977 @save_ids
970 @save_ids
978 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
979 after=None, follow=None, timeout=None,
972 after=None, follow=None, timeout=None,
980 targets=None, retries=None):
973 targets=None, retries=None):
981 """calls f(*args, **kwargs) on a remote engine, returning the result.
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
982
975
983 This method temporarily sets all of `apply`'s flags for a single call.
976 This method temporarily sets all of `apply`'s flags for a single call.
984
977
985 Parameters
978 Parameters
986 ----------
979 ----------
987
980
988 f : callable
981 f : callable
989
982
990 args : list [default: empty]
983 args : list [default: empty]
991
984
992 kwargs : dict [default: empty]
985 kwargs : dict [default: empty]
993
986
994 block : bool [default: self.block]
987 block : bool [default: self.block]
995 whether to block
988 whether to block
996 track : bool [default: self.track]
989 track : bool [default: self.track]
997 whether to ask zmq to track the message, for safe non-copying sends
990 whether to ask zmq to track the message, for safe non-copying sends
998
991
999 !!!!!! TODO: THE REST HERE !!!!
992 !!!!!! TODO: THE REST HERE !!!!
1000
993
1001 Returns
994 Returns
1002 -------
995 -------
1003
996
1004 if self.block is False:
997 if self.block is False:
1005 returns AsyncResult
998 returns AsyncResult
1006 else:
999 else:
1007 returns actual result of f(*args, **kwargs) on the engine(s)
1000 returns actual result of f(*args, **kwargs) on the engine(s)
1008 This will be a list of self.targets is also a list (even length 1), or
1001 This will be a list of self.targets is also a list (even length 1), or
1009 the single result if self.targets is an integer engine id
1002 the single result if self.targets is an integer engine id
1010 """
1003 """
1011
1004
1012 # validate whether we can run
1005 # validate whether we can run
1013 if self._socket.closed:
1006 if self._socket.closed:
1014 msg = "Task farming is disabled"
1007 msg = "Task farming is disabled"
1015 if self._task_scheme == 'pure':
1008 if self._task_scheme == 'pure':
1016 msg += " because the pure ZMQ scheduler cannot handle"
1009 msg += " because the pure ZMQ scheduler cannot handle"
1017 msg += " disappearing engines."
1010 msg += " disappearing engines."
1018 raise RuntimeError(msg)
1011 raise RuntimeError(msg)
1019
1012
1020 if self._task_scheme == 'pure':
1013 if self._task_scheme == 'pure':
1021 # pure zmq scheme doesn't support extra features
1014 # pure zmq scheme doesn't support extra features
1022 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1023 "follow, after, retries, targets, timeout"
1016 "follow, after, retries, targets, timeout"
1024 if (follow or after or retries or targets or timeout):
1017 if (follow or after or retries or targets or timeout):
1025 # hard fail on Scheduler flags
1018 # hard fail on Scheduler flags
1026 raise RuntimeError(msg)
1019 raise RuntimeError(msg)
1027 if isinstance(f, dependent):
1020 if isinstance(f, dependent):
1028 # soft warn on functional dependencies
1021 # soft warn on functional dependencies
1029 warnings.warn(msg, RuntimeWarning)
1022 warnings.warn(msg, RuntimeWarning)
1030
1023
1031 # build args
1024 # build args
1032 args = [] if args is None else args
1025 args = [] if args is None else args
1033 kwargs = {} if kwargs is None else kwargs
1026 kwargs = {} if kwargs is None else kwargs
1034 block = self.block if block is None else block
1027 block = self.block if block is None else block
1035 track = self.track if track is None else track
1028 track = self.track if track is None else track
1036 after = self.after if after is None else after
1029 after = self.after if after is None else after
1037 retries = self.retries if retries is None else retries
1030 retries = self.retries if retries is None else retries
1038 follow = self.follow if follow is None else follow
1031 follow = self.follow if follow is None else follow
1039 timeout = self.timeout if timeout is None else timeout
1032 timeout = self.timeout if timeout is None else timeout
1040 targets = self.targets if targets is None else targets
1033 targets = self.targets if targets is None else targets
1041
1034
1042 if not isinstance(retries, int):
1035 if not isinstance(retries, int):
1043 raise TypeError('retries must be int, not %r'%type(retries))
1036 raise TypeError('retries must be int, not %r'%type(retries))
1044
1037
1045 if targets is None:
1038 if targets is None:
1046 idents = []
1039 idents = []
1047 else:
1040 else:
1048 idents = self.client._build_targets(targets)[0]
1041 idents = self.client._build_targets(targets)[0]
1049 # ensure *not* bytes
1042 # ensure *not* bytes
1050 idents = [ ident.decode() for ident in idents ]
1043 idents = [ ident.decode() for ident in idents ]
1051
1044
1052 after = self._render_dependency(after)
1045 after = self._render_dependency(after)
1053 follow = self._render_dependency(follow)
1046 follow = self._render_dependency(follow)
1054 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1055
1048
1056 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1057 metadata=metadata)
1050 metadata=metadata)
1058 tracker = None if track is False else msg['tracker']
1051 tracker = None if track is False else msg['tracker']
1059
1052
1060 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1061
1054 targets=None, tracker=tracker, owner=True,
1055 )
1062 if block:
1056 if block:
1063 try:
1057 try:
1064 return ar.get()
1058 return ar.get()
1065 except KeyboardInterrupt:
1059 except KeyboardInterrupt:
1066 pass
1060 pass
1067 return ar
1061 return ar
1068
1062
1069 @sync_results
1063 @sync_results
1070 @save_ids
1064 @save_ids
1071 def map(self, f, *sequences, **kwargs):
1065 def map(self, f, *sequences, **kwargs):
1072 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1073
1067
1074 Parallel version of builtin `map`, load-balanced by this View.
1068 Parallel version of builtin `map`, load-balanced by this View.
1075
1069
1076 `block`, and `chunksize` can be specified by keyword only.
1070 `block`, and `chunksize` can be specified by keyword only.
1077
1071
1078 Each `chunksize` elements will be a separate task, and will be
1072 Each `chunksize` elements will be a separate task, and will be
1079 load-balanced. This lets individual elements be available for iteration
1073 load-balanced. This lets individual elements be available for iteration
1080 as soon as they arrive.
1074 as soon as they arrive.
1081
1075
1082 Parameters
1076 Parameters
1083 ----------
1077 ----------
1084
1078
1085 f : callable
1079 f : callable
1086 function to be mapped
1080 function to be mapped
1087 *sequences: one or more sequences of matching length
1081 *sequences: one or more sequences of matching length
1088 the sequences to be distributed and passed to `f`
1082 the sequences to be distributed and passed to `f`
1089 block : bool [default self.block]
1083 block : bool [default self.block]
1090 whether to wait for the result or not
1084 whether to wait for the result or not
1091 track : bool
1085 track : bool
1092 whether to create a MessageTracker to allow the user to
1086 whether to create a MessageTracker to allow the user to
1093 safely edit after arrays and buffers during non-copying
1087 safely edit after arrays and buffers during non-copying
1094 sends.
1088 sends.
1095 chunksize : int [default 1]
1089 chunksize : int [default 1]
1096 how many elements should be in each task.
1090 how many elements should be in each task.
1097 ordered : bool [default True]
1091 ordered : bool [default True]
1098 Whether the results should be gathered as they arrive, or enforce
1092 Whether the results should be gathered as they arrive, or enforce
1099 the order of submission.
1093 the order of submission.
1100
1094
1101 Only applies when iterating through AsyncMapResult as results arrive.
1095 Only applies when iterating through AsyncMapResult as results arrive.
1102 Has no effect when block=True.
1096 Has no effect when block=True.
1103
1097
1104 Returns
1098 Returns
1105 -------
1099 -------
1106
1100
1107 if block=False
1101 if block=False
1108 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1102 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1109 An object like AsyncResult, but which reassembles the sequence of results
1103 An object like AsyncResult, but which reassembles the sequence of results
1110 into a single list. AsyncMapResults can be iterated through before all
1104 into a single list. AsyncMapResults can be iterated through before all
1111 results are complete.
1105 results are complete.
1112 else
1106 else
1113 A list, the result of ``map(f,*sequences)``
1107 A list, the result of ``map(f,*sequences)``
1114 """
1108 """
1115
1109
1116 # default
1110 # default
1117 block = kwargs.get('block', self.block)
1111 block = kwargs.get('block', self.block)
1118 chunksize = kwargs.get('chunksize', 1)
1112 chunksize = kwargs.get('chunksize', 1)
1119 ordered = kwargs.get('ordered', True)
1113 ordered = kwargs.get('ordered', True)
1120
1114
1121 keyset = set(kwargs.keys())
1115 keyset = set(kwargs.keys())
1122 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1123 if extra_keys:
1117 if extra_keys:
1124 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1125
1119
1126 assert len(sequences) > 0, "must have some sequences to map onto!"
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1127
1121
1128 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1129 return pf.map(*sequences)
1123 return pf.map(*sequences)
1130
1124
1131 __all__ = ['LoadBalancedView', 'DirectView']
1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,327 +1,342 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py"""
2
2
3 Authors:
3 # Copyright (c) IPython Development Team.
4
4 # Distributed under the terms of the Modified BSD License.
5 * Min RK
6 """
7
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
14
15 #-------------------------------------------------------------------------------
16 # Imports
17 #-------------------------------------------------------------------------------
18
5
19 import time
6 import time
20
7
21 import nose.tools as nt
8 import nose.tools as nt
22
9
23 from IPython.utils.io import capture_output
10 from IPython.utils.io import capture_output
24
11
25 from IPython.parallel.error import TimeoutError
12 from IPython.parallel.error import TimeoutError
26 from IPython.parallel import error, Client
13 from IPython.parallel import error, Client
27 from IPython.parallel.tests import add_engines
14 from IPython.parallel.tests import add_engines
28 from .clienttest import ClusterTestCase
15 from .clienttest import ClusterTestCase
29 from IPython.utils.py3compat import iteritems
16 from IPython.utils.py3compat import iteritems
30
17
31 def setup():
18 def setup():
32 add_engines(2, total=True)
19 add_engines(2, total=True)
33
20
34 def wait(n):
21 def wait(n):
35 import time
22 import time
36 time.sleep(n)
23 time.sleep(n)
37 return n
24 return n
38
25
39 def echo(x):
26 def echo(x):
40 return x
27 return x
41
28
42 class AsyncResultTest(ClusterTestCase):
29 class AsyncResultTest(ClusterTestCase):
43
30
44 def test_single_result_view(self):
31 def test_single_result_view(self):
45 """various one-target views get the right value for single_result"""
32 """various one-target views get the right value for single_result"""
46 eid = self.client.ids[-1]
33 eid = self.client.ids[-1]
47 ar = self.client[eid].apply_async(lambda : 42)
34 ar = self.client[eid].apply_async(lambda : 42)
48 self.assertEqual(ar.get(), 42)
35 self.assertEqual(ar.get(), 42)
49 ar = self.client[[eid]].apply_async(lambda : 42)
36 ar = self.client[[eid]].apply_async(lambda : 42)
50 self.assertEqual(ar.get(), [42])
37 self.assertEqual(ar.get(), [42])
51 ar = self.client[-1:].apply_async(lambda : 42)
38 ar = self.client[-1:].apply_async(lambda : 42)
52 self.assertEqual(ar.get(), [42])
39 self.assertEqual(ar.get(), [42])
53
40
54 def test_get_after_done(self):
41 def test_get_after_done(self):
55 ar = self.client[-1].apply_async(lambda : 42)
42 ar = self.client[-1].apply_async(lambda : 42)
56 ar.wait()
43 ar.wait()
57 self.assertTrue(ar.ready())
44 self.assertTrue(ar.ready())
58 self.assertEqual(ar.get(), 42)
45 self.assertEqual(ar.get(), 42)
59 self.assertEqual(ar.get(), 42)
46 self.assertEqual(ar.get(), 42)
60
47
61 def test_get_before_done(self):
48 def test_get_before_done(self):
62 ar = self.client[-1].apply_async(wait, 0.1)
49 ar = self.client[-1].apply_async(wait, 0.1)
63 self.assertRaises(TimeoutError, ar.get, 0)
50 self.assertRaises(TimeoutError, ar.get, 0)
64 ar.wait(0)
51 ar.wait(0)
65 self.assertFalse(ar.ready())
52 self.assertFalse(ar.ready())
66 self.assertEqual(ar.get(), 0.1)
53 self.assertEqual(ar.get(), 0.1)
67
54
68 def test_get_after_error(self):
55 def test_get_after_error(self):
69 ar = self.client[-1].apply_async(lambda : 1/0)
56 ar = self.client[-1].apply_async(lambda : 1/0)
70 ar.wait(10)
57 ar.wait(10)
71 self.assertRaisesRemote(ZeroDivisionError, ar.get)
58 self.assertRaisesRemote(ZeroDivisionError, ar.get)
72 self.assertRaisesRemote(ZeroDivisionError, ar.get)
59 self.assertRaisesRemote(ZeroDivisionError, ar.get)
73 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
60 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
74
61
75 def test_get_dict(self):
62 def test_get_dict(self):
76 n = len(self.client)
63 n = len(self.client)
77 ar = self.client[:].apply_async(lambda : 5)
64 ar = self.client[:].apply_async(lambda : 5)
78 self.assertEqual(ar.get(), [5]*n)
65 self.assertEqual(ar.get(), [5]*n)
79 d = ar.get_dict()
66 d = ar.get_dict()
80 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
67 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
81 for eid,r in iteritems(d):
68 for eid,r in iteritems(d):
82 self.assertEqual(r, 5)
69 self.assertEqual(r, 5)
83
70
84 def test_get_dict_single(self):
71 def test_get_dict_single(self):
85 view = self.client[-1]
72 view = self.client[-1]
86 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
73 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
87 ar = view.apply_async(echo, v)
74 ar = view.apply_async(echo, v)
88 self.assertEqual(ar.get(), v)
75 self.assertEqual(ar.get(), v)
89 d = ar.get_dict()
76 d = ar.get_dict()
90 self.assertEqual(d, {view.targets : v})
77 self.assertEqual(d, {view.targets : v})
91
78
92 def test_get_dict_bad(self):
79 def test_get_dict_bad(self):
93 ar = self.client[:].apply_async(lambda : 5)
80 ar = self.client[:].apply_async(lambda : 5)
94 ar2 = self.client[:].apply_async(lambda : 5)
81 ar2 = self.client[:].apply_async(lambda : 5)
95 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
82 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
96 self.assertRaises(ValueError, ar.get_dict)
83 self.assertRaises(ValueError, ar.get_dict)
97
84
98 def test_list_amr(self):
85 def test_list_amr(self):
99 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
86 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
100 rlist = list(ar)
87 rlist = list(ar)
101
88
102 def test_getattr(self):
89 def test_getattr(self):
103 ar = self.client[:].apply_async(wait, 0.5)
90 ar = self.client[:].apply_async(wait, 0.5)
104 self.assertEqual(ar.engine_id, [None] * len(ar))
91 self.assertEqual(ar.engine_id, [None] * len(ar))
105 self.assertRaises(AttributeError, lambda : ar._foo)
92 self.assertRaises(AttributeError, lambda : ar._foo)
106 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
107 self.assertRaises(AttributeError, lambda : ar.foo)
94 self.assertRaises(AttributeError, lambda : ar.foo)
108 self.assertFalse(hasattr(ar, '__length_hint__'))
95 self.assertFalse(hasattr(ar, '__length_hint__'))
109 self.assertFalse(hasattr(ar, 'foo'))
96 self.assertFalse(hasattr(ar, 'foo'))
110 self.assertTrue(hasattr(ar, 'engine_id'))
97 self.assertTrue(hasattr(ar, 'engine_id'))
111 ar.get(5)
98 ar.get(5)
112 self.assertRaises(AttributeError, lambda : ar._foo)
99 self.assertRaises(AttributeError, lambda : ar._foo)
113 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
100 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
114 self.assertRaises(AttributeError, lambda : ar.foo)
101 self.assertRaises(AttributeError, lambda : ar.foo)
115 self.assertTrue(isinstance(ar.engine_id, list))
102 self.assertTrue(isinstance(ar.engine_id, list))
116 self.assertEqual(ar.engine_id, ar['engine_id'])
103 self.assertEqual(ar.engine_id, ar['engine_id'])
117 self.assertFalse(hasattr(ar, '__length_hint__'))
104 self.assertFalse(hasattr(ar, '__length_hint__'))
118 self.assertFalse(hasattr(ar, 'foo'))
105 self.assertFalse(hasattr(ar, 'foo'))
119 self.assertTrue(hasattr(ar, 'engine_id'))
106 self.assertTrue(hasattr(ar, 'engine_id'))
120
107
121 def test_getitem(self):
108 def test_getitem(self):
122 ar = self.client[:].apply_async(wait, 0.5)
109 ar = self.client[:].apply_async(wait, 0.5)
123 self.assertEqual(ar['engine_id'], [None] * len(ar))
110 self.assertEqual(ar['engine_id'], [None] * len(ar))
124 self.assertRaises(KeyError, lambda : ar['foo'])
111 self.assertRaises(KeyError, lambda : ar['foo'])
125 ar.get(5)
112 ar.get(5)
126 self.assertRaises(KeyError, lambda : ar['foo'])
113 self.assertRaises(KeyError, lambda : ar['foo'])
127 self.assertTrue(isinstance(ar['engine_id'], list))
114 self.assertTrue(isinstance(ar['engine_id'], list))
128 self.assertEqual(ar.engine_id, ar['engine_id'])
115 self.assertEqual(ar.engine_id, ar['engine_id'])
129
116
130 def test_single_result(self):
117 def test_single_result(self):
131 ar = self.client[-1].apply_async(wait, 0.5)
118 ar = self.client[-1].apply_async(wait, 0.5)
132 self.assertRaises(KeyError, lambda : ar['foo'])
119 self.assertRaises(KeyError, lambda : ar['foo'])
133 self.assertEqual(ar['engine_id'], None)
120 self.assertEqual(ar['engine_id'], None)
134 self.assertTrue(ar.get(5) == 0.5)
121 self.assertTrue(ar.get(5) == 0.5)
135 self.assertTrue(isinstance(ar['engine_id'], int))
122 self.assertTrue(isinstance(ar['engine_id'], int))
136 self.assertTrue(isinstance(ar.engine_id, int))
123 self.assertTrue(isinstance(ar.engine_id, int))
137 self.assertEqual(ar.engine_id, ar['engine_id'])
124 self.assertEqual(ar.engine_id, ar['engine_id'])
138
125
139 def test_abort(self):
126 def test_abort(self):
140 e = self.client[-1]
127 e = self.client[-1]
141 ar = e.execute('import time; time.sleep(1)', block=False)
128 ar = e.execute('import time; time.sleep(1)', block=False)
142 ar2 = e.apply_async(lambda : 2)
129 ar2 = e.apply_async(lambda : 2)
143 ar2.abort()
130 ar2.abort()
144 self.assertRaises(error.TaskAborted, ar2.get)
131 self.assertRaises(error.TaskAborted, ar2.get)
145 ar.get()
132 ar.get()
146
133
147 def test_len(self):
134 def test_len(self):
148 v = self.client.load_balanced_view()
135 v = self.client.load_balanced_view()
149 ar = v.map_async(lambda x: x, list(range(10)))
136 ar = v.map_async(lambda x: x, list(range(10)))
150 self.assertEqual(len(ar), 10)
137 self.assertEqual(len(ar), 10)
151 ar = v.apply_async(lambda x: x, list(range(10)))
138 ar = v.apply_async(lambda x: x, list(range(10)))
152 self.assertEqual(len(ar), 1)
139 self.assertEqual(len(ar), 1)
153 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
140 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
154 self.assertEqual(len(ar), len(self.client.ids))
141 self.assertEqual(len(ar), len(self.client.ids))
155
142
156 def test_wall_time_single(self):
143 def test_wall_time_single(self):
157 v = self.client.load_balanced_view()
144 v = self.client.load_balanced_view()
158 ar = v.apply_async(time.sleep, 0.25)
145 ar = v.apply_async(time.sleep, 0.25)
159 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
146 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
160 ar.get(2)
147 ar.get(2)
161 self.assertTrue(ar.wall_time < 1.)
148 self.assertTrue(ar.wall_time < 1.)
162 self.assertTrue(ar.wall_time > 0.2)
149 self.assertTrue(ar.wall_time > 0.2)
163
150
164 def test_wall_time_multi(self):
151 def test_wall_time_multi(self):
165 self.minimum_engines(4)
152 self.minimum_engines(4)
166 v = self.client[:]
153 v = self.client[:]
167 ar = v.apply_async(time.sleep, 0.25)
154 ar = v.apply_async(time.sleep, 0.25)
168 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
155 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
169 ar.get(2)
156 ar.get(2)
170 self.assertTrue(ar.wall_time < 1.)
157 self.assertTrue(ar.wall_time < 1.)
171 self.assertTrue(ar.wall_time > 0.2)
158 self.assertTrue(ar.wall_time > 0.2)
172
159
173 def test_serial_time_single(self):
160 def test_serial_time_single(self):
174 v = self.client.load_balanced_view()
161 v = self.client.load_balanced_view()
175 ar = v.apply_async(time.sleep, 0.25)
162 ar = v.apply_async(time.sleep, 0.25)
176 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
163 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
177 ar.get(2)
164 ar.get(2)
178 self.assertTrue(ar.serial_time < 1.)
165 self.assertTrue(ar.serial_time < 1.)
179 self.assertTrue(ar.serial_time > 0.2)
166 self.assertTrue(ar.serial_time > 0.2)
180
167
181 def test_serial_time_multi(self):
168 def test_serial_time_multi(self):
182 self.minimum_engines(4)
169 self.minimum_engines(4)
183 v = self.client[:]
170 v = self.client[:]
184 ar = v.apply_async(time.sleep, 0.25)
171 ar = v.apply_async(time.sleep, 0.25)
185 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
172 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
186 ar.get(2)
173 ar.get(2)
187 self.assertTrue(ar.serial_time < 2.)
174 self.assertTrue(ar.serial_time < 2.)
188 self.assertTrue(ar.serial_time > 0.8)
175 self.assertTrue(ar.serial_time > 0.8)
189
176
190 def test_elapsed_single(self):
177 def test_elapsed_single(self):
191 v = self.client.load_balanced_view()
178 v = self.client.load_balanced_view()
192 ar = v.apply_async(time.sleep, 0.25)
179 ar = v.apply_async(time.sleep, 0.25)
193 while not ar.ready():
180 while not ar.ready():
194 time.sleep(0.01)
181 time.sleep(0.01)
195 self.assertTrue(ar.elapsed < 1)
182 self.assertTrue(ar.elapsed < 1)
196 self.assertTrue(ar.elapsed < 1)
183 self.assertTrue(ar.elapsed < 1)
197 ar.get(2)
184 ar.get(2)
198
185
199 def test_elapsed_multi(self):
186 def test_elapsed_multi(self):
200 v = self.client[:]
187 v = self.client[:]
201 ar = v.apply_async(time.sleep, 0.25)
188 ar = v.apply_async(time.sleep, 0.25)
202 while not ar.ready():
189 while not ar.ready():
203 time.sleep(0.01)
190 time.sleep(0.01)
204 self.assertTrue(ar.elapsed < 1)
191 self.assertTrue(ar.elapsed < 1)
205 self.assertTrue(ar.elapsed < 1)
192 self.assertTrue(ar.elapsed < 1)
206 ar.get(2)
193 ar.get(2)
207
194
208 def test_hubresult_timestamps(self):
195 def test_hubresult_timestamps(self):
209 self.minimum_engines(4)
196 self.minimum_engines(4)
210 v = self.client[:]
197 v = self.client[:]
211 ar = v.apply_async(time.sleep, 0.25)
198 ar = v.apply_async(time.sleep, 0.25)
212 ar.get(2)
199 ar.get(2)
213 rc2 = Client(profile='iptest')
200 rc2 = Client(profile='iptest')
214 # must have try/finally to close second Client, otherwise
201 # must have try/finally to close second Client, otherwise
215 # will have dangling sockets causing problems
202 # will have dangling sockets causing problems
216 try:
203 try:
217 time.sleep(0.25)
204 time.sleep(0.25)
218 hr = rc2.get_result(ar.msg_ids)
205 hr = rc2.get_result(ar.msg_ids)
219 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
206 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
220 hr.get(1)
207 hr.get(1)
221 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
208 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
222 self.assertEqual(hr.serial_time, ar.serial_time)
209 self.assertEqual(hr.serial_time, ar.serial_time)
223 finally:
210 finally:
224 rc2.close()
211 rc2.close()
225
212
226 def test_display_empty_streams_single(self):
213 def test_display_empty_streams_single(self):
227 """empty stdout/err are not displayed (single result)"""
214 """empty stdout/err are not displayed (single result)"""
228 self.minimum_engines(1)
215 self.minimum_engines(1)
229
216
230 v = self.client[-1]
217 v = self.client[-1]
231 ar = v.execute("print (5555)")
218 ar = v.execute("print (5555)")
232 ar.get(5)
219 ar.get(5)
233 with capture_output() as io:
220 with capture_output() as io:
234 ar.display_outputs()
221 ar.display_outputs()
235 self.assertEqual(io.stderr, '')
222 self.assertEqual(io.stderr, '')
236 self.assertEqual('5555\n', io.stdout)
223 self.assertEqual('5555\n', io.stdout)
237
224
238 ar = v.execute("a=5")
225 ar = v.execute("a=5")
239 ar.get(5)
226 ar.get(5)
240 with capture_output() as io:
227 with capture_output() as io:
241 ar.display_outputs()
228 ar.display_outputs()
242 self.assertEqual(io.stderr, '')
229 self.assertEqual(io.stderr, '')
243 self.assertEqual(io.stdout, '')
230 self.assertEqual(io.stdout, '')
244
231
245 def test_display_empty_streams_type(self):
232 def test_display_empty_streams_type(self):
246 """empty stdout/err are not displayed (groupby type)"""
233 """empty stdout/err are not displayed (groupby type)"""
247 self.minimum_engines(1)
234 self.minimum_engines(1)
248
235
249 v = self.client[:]
236 v = self.client[:]
250 ar = v.execute("print (5555)")
237 ar = v.execute("print (5555)")
251 ar.get(5)
238 ar.get(5)
252 with capture_output() as io:
239 with capture_output() as io:
253 ar.display_outputs()
240 ar.display_outputs()
254 self.assertEqual(io.stderr, '')
241 self.assertEqual(io.stderr, '')
255 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
242 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
256 self.assertFalse('\n\n' in io.stdout, io.stdout)
243 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
244 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
258
245
259 ar = v.execute("a=5")
246 ar = v.execute("a=5")
260 ar.get(5)
247 ar.get(5)
261 with capture_output() as io:
248 with capture_output() as io:
262 ar.display_outputs()
249 ar.display_outputs()
263 self.assertEqual(io.stderr, '')
250 self.assertEqual(io.stderr, '')
264 self.assertEqual(io.stdout, '')
251 self.assertEqual(io.stdout, '')
265
252
266 def test_display_empty_streams_engine(self):
253 def test_display_empty_streams_engine(self):
267 """empty stdout/err are not displayed (groupby engine)"""
254 """empty stdout/err are not displayed (groupby engine)"""
268 self.minimum_engines(1)
255 self.minimum_engines(1)
269
256
270 v = self.client[:]
257 v = self.client[:]
271 ar = v.execute("print (5555)")
258 ar = v.execute("print (5555)")
272 ar.get(5)
259 ar.get(5)
273 with capture_output() as io:
260 with capture_output() as io:
274 ar.display_outputs('engine')
261 ar.display_outputs('engine')
275 self.assertEqual(io.stderr, '')
262 self.assertEqual(io.stderr, '')
276 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
263 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
277 self.assertFalse('\n\n' in io.stdout, io.stdout)
264 self.assertFalse('\n\n' in io.stdout, io.stdout)
278 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
265 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
279
266
280 ar = v.execute("a=5")
267 ar = v.execute("a=5")
281 ar.get(5)
268 ar.get(5)
282 with capture_output() as io:
269 with capture_output() as io:
283 ar.display_outputs('engine')
270 ar.display_outputs('engine')
284 self.assertEqual(io.stderr, '')
271 self.assertEqual(io.stderr, '')
285 self.assertEqual(io.stdout, '')
272 self.assertEqual(io.stdout, '')
286
273
287 def test_await_data(self):
274 def test_await_data(self):
288 """asking for ar.data flushes outputs"""
275 """asking for ar.data flushes outputs"""
289 self.minimum_engines(1)
276 self.minimum_engines(1)
290
277
291 v = self.client[-1]
278 v = self.client[-1]
292 ar = v.execute('\n'.join([
279 ar = v.execute('\n'.join([
293 "import time",
280 "import time",
294 "from IPython.kernel.zmq.datapub import publish_data",
281 "from IPython.kernel.zmq.datapub import publish_data",
295 "for i in range(5):",
282 "for i in range(5):",
296 " publish_data(dict(i=i))",
283 " publish_data(dict(i=i))",
297 " time.sleep(0.1)",
284 " time.sleep(0.1)",
298 ]), block=False)
285 ]), block=False)
299 found = set()
286 found = set()
300 tic = time.time()
287 tic = time.time()
301 # timeout after 10s
288 # timeout after 10s
302 while time.time() <= tic + 10:
289 while time.time() <= tic + 10:
303 if ar.data:
290 if ar.data:
304 i = ar.data['i']
291 i = ar.data['i']
305 found.add(i)
292 found.add(i)
306 if i == 4:
293 if i == 4:
307 break
294 break
308 time.sleep(0.05)
295 time.sleep(0.05)
309
296
310 ar.get(5)
297 ar.get(5)
311 nt.assert_in(4, found)
298 nt.assert_in(4, found)
312 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
299 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
313
300
314 def test_not_single_result(self):
301 def test_not_single_result(self):
315 save_build = self.client._build_targets
302 save_build = self.client._build_targets
316 def single_engine(*a, **kw):
303 def single_engine(*a, **kw):
317 idents, targets = save_build(*a, **kw)
304 idents, targets = save_build(*a, **kw)
318 return idents[:1], targets[:1]
305 return idents[:1], targets[:1]
319 ids = single_engine('all')[1]
306 ids = single_engine('all')[1]
320 self.client._build_targets = single_engine
307 self.client._build_targets = single_engine
321 for targets in ('all', None, ids):
308 for targets in ('all', None, ids):
322 dv = self.client.direct_view(targets=targets)
309 dv = self.client.direct_view(targets=targets)
323 ar = dv.apply_async(lambda : 5)
310 ar = dv.apply_async(lambda : 5)
324 self.assertEqual(ar.get(10), [5])
311 self.assertEqual(ar.get(10), [5])
325 self.client._build_targets = save_build
312 self.client._build_targets = save_build
326
313
314 def test_owner_pop(self):
315 self.minimum_engines(1)
316
317 view = self.client[-1]
318 ar = view.apply_async(lambda : 1)
319 ar.get()
320 msg_id = ar.msg_ids[0]
321 self.assertNotIn(msg_id, self.client.results)
322 self.assertNotIn(msg_id, self.client.metadata)
323
324 def test_non_owner(self):
325 self.minimum_engines(1)
326
327 view = self.client[-1]
328 ar = view.apply_async(lambda : 1)
329 ar.owner = False
330 ar.get()
331 msg_id = ar.msg_ids[0]
332 self.assertIn(msg_id, self.client.results)
333 self.assertIn(msg_id, self.client.metadata)
334 ar2 = self.client.get_result(msg_id, owner=True)
335 self.assertIs(type(ar2), type(ar))
336 self.assertTrue(ar2.owner)
337 self.assertEqual(ar.get(), ar2.get())
338 ar2.get()
339 self.assertNotIn(msg_id, self.client.results)
340 self.assertNotIn(msg_id, self.client.metadata)
341
327
342
@@ -1,547 +1,550 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import division
6 from __future__ import division
7
7
8 import time
8 import time
9 from datetime import datetime
9 from datetime import datetime
10
10
11 import zmq
11 import zmq
12
12
13 from IPython import parallel
13 from IPython import parallel
14 from IPython.parallel.client import client as clientmod
14 from IPython.parallel.client import client as clientmod
15 from IPython.parallel import error
15 from IPython.parallel import error
16 from IPython.parallel import AsyncResult, AsyncHubResult
16 from IPython.parallel import AsyncResult, AsyncHubResult
17 from IPython.parallel import LoadBalancedView, DirectView
17 from IPython.parallel import LoadBalancedView, DirectView
18
18
19 from .clienttest import ClusterTestCase, segfault, wait, add_engines
19 from .clienttest import ClusterTestCase, segfault, wait, add_engines
20
20
21 def setup():
21 def setup():
22 add_engines(4, total=True)
22 add_engines(4, total=True)
23
23
24 class TestClient(ClusterTestCase):
24 class TestClient(ClusterTestCase):
25
25
26 def test_ids(self):
26 def test_ids(self):
27 n = len(self.client.ids)
27 n = len(self.client.ids)
28 self.add_engines(2)
28 self.add_engines(2)
29 self.assertEqual(len(self.client.ids), n+2)
29 self.assertEqual(len(self.client.ids), n+2)
30
30
31 def test_iter(self):
31 def test_iter(self):
32 self.minimum_engines(4)
32 self.minimum_engines(4)
33 engine_ids = [ view.targets for view in self.client ]
33 engine_ids = [ view.targets for view in self.client ]
34 self.assertEqual(engine_ids, self.client.ids)
34 self.assertEqual(engine_ids, self.client.ids)
35
35
36 def test_view_indexing(self):
36 def test_view_indexing(self):
37 """test index access for views"""
37 """test index access for views"""
38 self.minimum_engines(4)
38 self.minimum_engines(4)
39 targets = self.client._build_targets('all')[-1]
39 targets = self.client._build_targets('all')[-1]
40 v = self.client[:]
40 v = self.client[:]
41 self.assertEqual(v.targets, targets)
41 self.assertEqual(v.targets, targets)
42 t = self.client.ids[2]
42 t = self.client.ids[2]
43 v = self.client[t]
43 v = self.client[t]
44 self.assertTrue(isinstance(v, DirectView))
44 self.assertTrue(isinstance(v, DirectView))
45 self.assertEqual(v.targets, t)
45 self.assertEqual(v.targets, t)
46 t = self.client.ids[2:4]
46 t = self.client.ids[2:4]
47 v = self.client[t]
47 v = self.client[t]
48 self.assertTrue(isinstance(v, DirectView))
48 self.assertTrue(isinstance(v, DirectView))
49 self.assertEqual(v.targets, t)
49 self.assertEqual(v.targets, t)
50 v = self.client[::2]
50 v = self.client[::2]
51 self.assertTrue(isinstance(v, DirectView))
51 self.assertTrue(isinstance(v, DirectView))
52 self.assertEqual(v.targets, targets[::2])
52 self.assertEqual(v.targets, targets[::2])
53 v = self.client[1::3]
53 v = self.client[1::3]
54 self.assertTrue(isinstance(v, DirectView))
54 self.assertTrue(isinstance(v, DirectView))
55 self.assertEqual(v.targets, targets[1::3])
55 self.assertEqual(v.targets, targets[1::3])
56 v = self.client[:-3]
56 v = self.client[:-3]
57 self.assertTrue(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, targets[:-3])
58 self.assertEqual(v.targets, targets[:-3])
59 v = self.client[-1]
59 v = self.client[-1]
60 self.assertTrue(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[-1])
61 self.assertEqual(v.targets, targets[-1])
62 self.assertRaises(TypeError, lambda : self.client[None])
62 self.assertRaises(TypeError, lambda : self.client[None])
63
63
64 def test_lbview_targets(self):
64 def test_lbview_targets(self):
65 """test load_balanced_view targets"""
65 """test load_balanced_view targets"""
66 v = self.client.load_balanced_view()
66 v = self.client.load_balanced_view()
67 self.assertEqual(v.targets, None)
67 self.assertEqual(v.targets, None)
68 v = self.client.load_balanced_view(-1)
68 v = self.client.load_balanced_view(-1)
69 self.assertEqual(v.targets, [self.client.ids[-1]])
69 self.assertEqual(v.targets, [self.client.ids[-1]])
70 v = self.client.load_balanced_view('all')
70 v = self.client.load_balanced_view('all')
71 self.assertEqual(v.targets, None)
71 self.assertEqual(v.targets, None)
72
72
73 def test_dview_targets(self):
73 def test_dview_targets(self):
74 """test direct_view targets"""
74 """test direct_view targets"""
75 v = self.client.direct_view()
75 v = self.client.direct_view()
76 self.assertEqual(v.targets, 'all')
76 self.assertEqual(v.targets, 'all')
77 v = self.client.direct_view('all')
77 v = self.client.direct_view('all')
78 self.assertEqual(v.targets, 'all')
78 self.assertEqual(v.targets, 'all')
79 v = self.client.direct_view(-1)
79 v = self.client.direct_view(-1)
80 self.assertEqual(v.targets, self.client.ids[-1])
80 self.assertEqual(v.targets, self.client.ids[-1])
81
81
82 def test_lazy_all_targets(self):
82 def test_lazy_all_targets(self):
83 """test lazy evaluation of rc.direct_view('all')"""
83 """test lazy evaluation of rc.direct_view('all')"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86
86
87 def double(x):
87 def double(x):
88 return x*2
88 return x*2
89 seq = list(range(100))
89 seq = list(range(100))
90 ref = [ double(x) for x in seq ]
90 ref = [ double(x) for x in seq ]
91
91
92 # add some engines, which should be used
92 # add some engines, which should be used
93 self.add_engines(1)
93 self.add_engines(1)
94 n1 = len(self.client.ids)
94 n1 = len(self.client.ids)
95
95
96 # simple apply
96 # simple apply
97 r = v.apply_sync(lambda : 1)
97 r = v.apply_sync(lambda : 1)
98 self.assertEqual(r, [1] * n1)
98 self.assertEqual(r, [1] * n1)
99
99
100 # map goes through remotefunction
100 # map goes through remotefunction
101 r = v.map_sync(double, seq)
101 r = v.map_sync(double, seq)
102 self.assertEqual(r, ref)
102 self.assertEqual(r, ref)
103
103
104 # add a couple more engines, and try again
104 # add a couple more engines, and try again
105 self.add_engines(2)
105 self.add_engines(2)
106 n2 = len(self.client.ids)
106 n2 = len(self.client.ids)
107 self.assertNotEqual(n2, n1)
107 self.assertNotEqual(n2, n1)
108
108
109 # apply
109 # apply
110 r = v.apply_sync(lambda : 1)
110 r = v.apply_sync(lambda : 1)
111 self.assertEqual(r, [1] * n2)
111 self.assertEqual(r, [1] * n2)
112
112
113 # map
113 # map
114 r = v.map_sync(double, seq)
114 r = v.map_sync(double, seq)
115 self.assertEqual(r, ref)
115 self.assertEqual(r, ref)
116
116
117 def test_targets(self):
117 def test_targets(self):
118 """test various valid targets arguments"""
118 """test various valid targets arguments"""
119 build = self.client._build_targets
119 build = self.client._build_targets
120 ids = self.client.ids
120 ids = self.client.ids
121 idents,targets = build(None)
121 idents,targets = build(None)
122 self.assertEqual(ids, targets)
122 self.assertEqual(ids, targets)
123
123
124 def test_clear(self):
124 def test_clear(self):
125 """test clear behavior"""
125 """test clear behavior"""
126 self.minimum_engines(2)
126 self.minimum_engines(2)
127 v = self.client[:]
127 v = self.client[:]
128 v.block=True
128 v.block=True
129 v.push(dict(a=5))
129 v.push(dict(a=5))
130 v.pull('a')
130 v.pull('a')
131 id0 = self.client.ids[-1]
131 id0 = self.client.ids[-1]
132 self.client.clear(targets=id0, block=True)
132 self.client.clear(targets=id0, block=True)
133 a = self.client[:-1].get('a')
133 a = self.client[:-1].get('a')
134 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
134 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
135 self.client.clear(block=True)
135 self.client.clear(block=True)
136 for i in self.client.ids:
136 for i in self.client.ids:
137 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
137 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
138
138
139 def test_get_result(self):
139 def test_get_result(self):
140 """test getting results from the Hub."""
140 """test getting results from the Hub."""
141 c = clientmod.Client(profile='iptest')
141 c = clientmod.Client(profile='iptest')
142 t = c.ids[-1]
142 t = c.ids[-1]
143 ar = c[t].apply_async(wait, 1)
143 ar = c[t].apply_async(wait, 1)
144 # give the monitor time to notice the message
144 # give the monitor time to notice the message
145 time.sleep(.25)
145 time.sleep(.25)
146 ahr = self.client.get_result(ar.msg_ids[0])
146 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertIsInstance(ahr, AsyncHubResult)
148 self.assertEqual(ahr.get(), ar.get())
148 self.assertEqual(ahr.get(), ar.get())
149 ar2 = self.client.get_result(ar.msg_ids[0])
149 ar2 = self.client.get_result(ar.msg_ids[0])
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 self.assertNotIsInstance(ar2, AsyncHubResult)
151 self.assertEqual(ahr.get(), ar2.get())
151 c.close()
152 c.close()
152
153
153 def test_get_execute_result(self):
154 def test_get_execute_result(self):
154 """test getting execute results from the Hub."""
155 """test getting execute results from the Hub."""
155 c = clientmod.Client(profile='iptest')
156 c = clientmod.Client(profile='iptest')
156 t = c.ids[-1]
157 t = c.ids[-1]
157 cell = '\n'.join([
158 cell = '\n'.join([
158 'import time',
159 'import time',
159 'time.sleep(0.25)',
160 'time.sleep(0.25)',
160 '5'
161 '5'
161 ])
162 ])
162 ar = c[t].execute("import time; time.sleep(1)", silent=False)
163 ar = c[t].execute("import time; time.sleep(1)", silent=False)
163 # give the monitor time to notice the message
164 # give the monitor time to notice the message
164 time.sleep(.25)
165 time.sleep(.25)
165 ahr = self.client.get_result(ar.msg_ids[0])
166 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
166 self.assertTrue(isinstance(ahr, AsyncHubResult))
167 self.assertIsInstance(ahr, AsyncHubResult)
167 self.assertEqual(ahr.get().execute_result, ar.get().execute_result)
168 self.assertEqual(ahr.get().execute_result, ar.get().execute_result)
168 ar2 = self.client.get_result(ar.msg_ids[0])
169 ar2 = self.client.get_result(ar.msg_ids[0])
169 self.assertFalse(isinstance(ar2, AsyncHubResult))
170 self.assertNotIsInstance(ar2, AsyncHubResult)
171 self.assertEqual(ahr.get(), ar2.get())
170 c.close()
172 c.close()
171
173
172 def test_ids_list(self):
174 def test_ids_list(self):
173 """test client.ids"""
175 """test client.ids"""
174 ids = self.client.ids
176 ids = self.client.ids
175 self.assertEqual(ids, self.client._ids)
177 self.assertEqual(ids, self.client._ids)
176 self.assertFalse(ids is self.client._ids)
178 self.assertFalse(ids is self.client._ids)
177 ids.remove(ids[-1])
179 ids.remove(ids[-1])
178 self.assertNotEqual(ids, self.client._ids)
180 self.assertNotEqual(ids, self.client._ids)
179
181
180 def test_queue_status(self):
182 def test_queue_status(self):
181 ids = self.client.ids
183 ids = self.client.ids
182 id0 = ids[0]
184 id0 = ids[0]
183 qs = self.client.queue_status(targets=id0)
185 qs = self.client.queue_status(targets=id0)
184 self.assertTrue(isinstance(qs, dict))
186 self.assertTrue(isinstance(qs, dict))
185 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
187 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
186 allqs = self.client.queue_status()
188 allqs = self.client.queue_status()
187 self.assertTrue(isinstance(allqs, dict))
189 self.assertTrue(isinstance(allqs, dict))
188 intkeys = list(allqs.keys())
190 intkeys = list(allqs.keys())
189 intkeys.remove('unassigned')
191 intkeys.remove('unassigned')
190 print("intkeys", intkeys)
192 print("intkeys", intkeys)
191 intkeys = sorted(intkeys)
193 intkeys = sorted(intkeys)
192 ids = self.client.ids
194 ids = self.client.ids
193 print("client.ids", ids)
195 print("client.ids", ids)
194 ids = sorted(self.client.ids)
196 ids = sorted(self.client.ids)
195 self.assertEqual(intkeys, ids)
197 self.assertEqual(intkeys, ids)
196 unassigned = allqs.pop('unassigned')
198 unassigned = allqs.pop('unassigned')
197 for eid,qs in allqs.items():
199 for eid,qs in allqs.items():
198 self.assertTrue(isinstance(qs, dict))
200 self.assertTrue(isinstance(qs, dict))
199 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
201 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
200
202
201 def test_shutdown(self):
203 def test_shutdown(self):
202 ids = self.client.ids
204 ids = self.client.ids
203 id0 = ids[0]
205 id0 = ids[0]
204 self.client.shutdown(id0, block=True)
206 self.client.shutdown(id0, block=True)
205 while id0 in self.client.ids:
207 while id0 in self.client.ids:
206 time.sleep(0.1)
208 time.sleep(0.1)
207 self.client.spin()
209 self.client.spin()
208
210
209 self.assertRaises(IndexError, lambda : self.client[id0])
211 self.assertRaises(IndexError, lambda : self.client[id0])
210
212
211 def test_result_status(self):
213 def test_result_status(self):
212 pass
214 pass
213 # to be written
215 # to be written
214
216
215 def test_db_query_dt(self):
217 def test_db_query_dt(self):
216 """test db query by date"""
218 """test db query by date"""
217 hist = self.client.hub_history()
219 hist = self.client.hub_history()
218 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
220 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
219 tic = middle['submitted']
221 tic = middle['submitted']
220 before = self.client.db_query({'submitted' : {'$lt' : tic}})
222 before = self.client.db_query({'submitted' : {'$lt' : tic}})
221 after = self.client.db_query({'submitted' : {'$gte' : tic}})
223 after = self.client.db_query({'submitted' : {'$gte' : tic}})
222 self.assertEqual(len(before)+len(after),len(hist))
224 self.assertEqual(len(before)+len(after),len(hist))
223 for b in before:
225 for b in before:
224 self.assertTrue(b['submitted'] < tic)
226 self.assertTrue(b['submitted'] < tic)
225 for a in after:
227 for a in after:
226 self.assertTrue(a['submitted'] >= tic)
228 self.assertTrue(a['submitted'] >= tic)
227 same = self.client.db_query({'submitted' : tic})
229 same = self.client.db_query({'submitted' : tic})
228 for s in same:
230 for s in same:
229 self.assertTrue(s['submitted'] == tic)
231 self.assertTrue(s['submitted'] == tic)
230
232
231 def test_db_query_keys(self):
233 def test_db_query_keys(self):
232 """test extracting subset of record keys"""
234 """test extracting subset of record keys"""
233 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
235 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
234 for rec in found:
236 for rec in found:
235 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
237 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
236
238
237 def test_db_query_default_keys(self):
239 def test_db_query_default_keys(self):
238 """default db_query excludes buffers"""
240 """default db_query excludes buffers"""
239 found = self.client.db_query({'msg_id': {'$ne' : ''}})
241 found = self.client.db_query({'msg_id': {'$ne' : ''}})
240 for rec in found:
242 for rec in found:
241 keys = set(rec.keys())
243 keys = set(rec.keys())
242 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
244 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
243 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
245 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
244
246
245 def test_db_query_msg_id(self):
247 def test_db_query_msg_id(self):
246 """ensure msg_id is always in db queries"""
248 """ensure msg_id is always in db queries"""
247 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
249 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
248 for rec in found:
250 for rec in found:
249 self.assertTrue('msg_id' in rec.keys())
251 self.assertTrue('msg_id' in rec.keys())
250 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
252 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
251 for rec in found:
253 for rec in found:
252 self.assertTrue('msg_id' in rec.keys())
254 self.assertTrue('msg_id' in rec.keys())
253 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
255 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
254 for rec in found:
256 for rec in found:
255 self.assertTrue('msg_id' in rec.keys())
257 self.assertTrue('msg_id' in rec.keys())
256
258
257 def test_db_query_get_result(self):
259 def test_db_query_get_result(self):
258 """pop in db_query shouldn't pop from result itself"""
260 """pop in db_query shouldn't pop from result itself"""
259 self.client[:].apply_sync(lambda : 1)
261 self.client[:].apply_sync(lambda : 1)
260 found = self.client.db_query({'msg_id': {'$ne' : ''}})
262 found = self.client.db_query({'msg_id': {'$ne' : ''}})
261 rc2 = clientmod.Client(profile='iptest')
263 rc2 = clientmod.Client(profile='iptest')
262 # If this bug is not fixed, this call will hang:
264 # If this bug is not fixed, this call will hang:
263 ar = rc2.get_result(self.client.history[-1])
265 ar = rc2.get_result(self.client.history[-1])
264 ar.wait(2)
266 ar.wait(2)
265 self.assertTrue(ar.ready())
267 self.assertTrue(ar.ready())
266 ar.get()
268 ar.get()
267 rc2.close()
269 rc2.close()
268
270
269 def test_db_query_in(self):
271 def test_db_query_in(self):
270 """test db query with '$in','$nin' operators"""
272 """test db query with '$in','$nin' operators"""
271 hist = self.client.hub_history()
273 hist = self.client.hub_history()
272 even = hist[::2]
274 even = hist[::2]
273 odd = hist[1::2]
275 odd = hist[1::2]
274 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
276 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
275 found = [ r['msg_id'] for r in recs ]
277 found = [ r['msg_id'] for r in recs ]
276 self.assertEqual(set(even), set(found))
278 self.assertEqual(set(even), set(found))
277 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
279 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
278 found = [ r['msg_id'] for r in recs ]
280 found = [ r['msg_id'] for r in recs ]
279 self.assertEqual(set(odd), set(found))
281 self.assertEqual(set(odd), set(found))
280
282
281 def test_hub_history(self):
283 def test_hub_history(self):
282 hist = self.client.hub_history()
284 hist = self.client.hub_history()
283 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
285 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
284 recdict = {}
286 recdict = {}
285 for rec in recs:
287 for rec in recs:
286 recdict[rec['msg_id']] = rec
288 recdict[rec['msg_id']] = rec
287
289
288 latest = datetime(1984,1,1)
290 latest = datetime(1984,1,1)
289 for msg_id in hist:
291 for msg_id in hist:
290 rec = recdict[msg_id]
292 rec = recdict[msg_id]
291 newt = rec['submitted']
293 newt = rec['submitted']
292 self.assertTrue(newt >= latest)
294 self.assertTrue(newt >= latest)
293 latest = newt
295 latest = newt
294 ar = self.client[-1].apply_async(lambda : 1)
296 ar = self.client[-1].apply_async(lambda : 1)
295 ar.get()
297 ar.get()
296 time.sleep(0.25)
298 time.sleep(0.25)
297 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
299 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
298
300
299 def _wait_for_idle(self):
301 def _wait_for_idle(self):
300 """wait for the cluster to become idle, according to the everyone."""
302 """wait for the cluster to become idle, according to the everyone."""
301 rc = self.client
303 rc = self.client
302
304
303 # step 0. wait for local results
305 # step 0. wait for local results
304 # this should be sufficient 99% of the time.
306 # this should be sufficient 99% of the time.
305 rc.wait(timeout=5)
307 rc.wait(timeout=5)
306
308
307 # step 1. wait for all requests to be noticed
309 # step 1. wait for all requests to be noticed
308 # timeout 5s, polling every 100ms
310 # timeout 5s, polling every 100ms
309 msg_ids = set(rc.history)
311 msg_ids = set(rc.history)
310 hub_hist = rc.hub_history()
312 hub_hist = rc.hub_history()
311 for i in range(50):
313 for i in range(50):
312 if msg_ids.difference(hub_hist):
314 if msg_ids.difference(hub_hist):
313 time.sleep(0.1)
315 time.sleep(0.1)
314 hub_hist = rc.hub_history()
316 hub_hist = rc.hub_history()
315 else:
317 else:
316 break
318 break
317
319
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
320 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319
321
320 # step 2. wait for all requests to be done
322 # step 2. wait for all requests to be done
321 # timeout 5s, polling every 100ms
323 # timeout 5s, polling every 100ms
322 qs = rc.queue_status()
324 qs = rc.queue_status()
323 for i in range(50):
325 for i in range(50):
324 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
326 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
325 time.sleep(0.1)
327 time.sleep(0.1)
326 qs = rc.queue_status()
328 qs = rc.queue_status()
327 else:
329 else:
328 break
330 break
329
331
330 # ensure Hub up to date:
332 # ensure Hub up to date:
331 self.assertEqual(qs['unassigned'], 0)
333 self.assertEqual(qs['unassigned'], 0)
332 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
334 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
333 self.assertEqual(qs[eid]['tasks'], 0)
335 self.assertEqual(qs[eid]['tasks'], 0)
334 self.assertEqual(qs[eid]['queue'], 0)
336 self.assertEqual(qs[eid]['queue'], 0)
335
337
336
338
337 def test_resubmit(self):
339 def test_resubmit(self):
338 def f():
340 def f():
339 import random
341 import random
340 return random.random()
342 return random.random()
341 v = self.client.load_balanced_view()
343 v = self.client.load_balanced_view()
342 ar = v.apply_async(f)
344 ar = v.apply_async(f)
343 r1 = ar.get(1)
345 r1 = ar.get(1)
344 # give the Hub a chance to notice:
346 # give the Hub a chance to notice:
345 self._wait_for_idle()
347 self._wait_for_idle()
346 ahr = self.client.resubmit(ar.msg_ids)
348 ahr = self.client.resubmit(ar.msg_ids)
347 r2 = ahr.get(1)
349 r2 = ahr.get(1)
348 self.assertFalse(r1 == r2)
350 self.assertFalse(r1 == r2)
349
351
350 def test_resubmit_chain(self):
352 def test_resubmit_chain(self):
351 """resubmit resubmitted tasks"""
353 """resubmit resubmitted tasks"""
352 v = self.client.load_balanced_view()
354 v = self.client.load_balanced_view()
353 ar = v.apply_async(lambda x: x, 'x'*1024)
355 ar = v.apply_async(lambda x: x, 'x'*1024)
354 ar.get()
356 ar.get()
355 self._wait_for_idle()
357 self._wait_for_idle()
356 ars = [ar]
358 ars = [ar]
357
359
358 for i in range(10):
360 for i in range(10):
359 ar = ars[-1]
361 ar = ars[-1]
360 ar2 = self.client.resubmit(ar.msg_ids)
362 ar2 = self.client.resubmit(ar.msg_ids)
361
363
362 [ ar.get() for ar in ars ]
364 [ ar.get() for ar in ars ]
363
365
364 def test_resubmit_header(self):
366 def test_resubmit_header(self):
365 """resubmit shouldn't clobber the whole header"""
367 """resubmit shouldn't clobber the whole header"""
366 def f():
368 def f():
367 import random
369 import random
368 return random.random()
370 return random.random()
369 v = self.client.load_balanced_view()
371 v = self.client.load_balanced_view()
370 v.retries = 1
372 v.retries = 1
371 ar = v.apply_async(f)
373 ar = v.apply_async(f)
372 r1 = ar.get(1)
374 r1 = ar.get(1)
373 # give the Hub a chance to notice:
375 # give the Hub a chance to notice:
374 self._wait_for_idle()
376 self._wait_for_idle()
375 ahr = self.client.resubmit(ar.msg_ids)
377 ahr = self.client.resubmit(ar.msg_ids)
376 ahr.get(1)
378 ahr.get(1)
377 time.sleep(0.5)
379 time.sleep(0.5)
378 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
380 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
379 h1,h2 = [ r['header'] for r in records ]
381 h1,h2 = [ r['header'] for r in records ]
380 for key in set(h1.keys()).union(set(h2.keys())):
382 for key in set(h1.keys()).union(set(h2.keys())):
381 if key in ('msg_id', 'date'):
383 if key in ('msg_id', 'date'):
382 self.assertNotEqual(h1[key], h2[key])
384 self.assertNotEqual(h1[key], h2[key])
383 else:
385 else:
384 self.assertEqual(h1[key], h2[key])
386 self.assertEqual(h1[key], h2[key])
385
387
386 def test_resubmit_aborted(self):
388 def test_resubmit_aborted(self):
387 def f():
389 def f():
388 import random
390 import random
389 return random.random()
391 return random.random()
390 v = self.client.load_balanced_view()
392 v = self.client.load_balanced_view()
391 # restrict to one engine, so we can put a sleep
393 # restrict to one engine, so we can put a sleep
392 # ahead of the task, so it will get aborted
394 # ahead of the task, so it will get aborted
393 eid = self.client.ids[-1]
395 eid = self.client.ids[-1]
394 v.targets = [eid]
396 v.targets = [eid]
395 sleep = v.apply_async(time.sleep, 0.5)
397 sleep = v.apply_async(time.sleep, 0.5)
396 ar = v.apply_async(f)
398 ar = v.apply_async(f)
397 ar.abort()
399 ar.abort()
398 self.assertRaises(error.TaskAborted, ar.get)
400 self.assertRaises(error.TaskAborted, ar.get)
399 # Give the Hub a chance to get up to date:
401 # Give the Hub a chance to get up to date:
400 self._wait_for_idle()
402 self._wait_for_idle()
401 ahr = self.client.resubmit(ar.msg_ids)
403 ahr = self.client.resubmit(ar.msg_ids)
402 r2 = ahr.get(1)
404 r2 = ahr.get(1)
403
405
404 def test_resubmit_inflight(self):
406 def test_resubmit_inflight(self):
405 """resubmit of inflight task"""
407 """resubmit of inflight task"""
406 v = self.client.load_balanced_view()
408 v = self.client.load_balanced_view()
407 ar = v.apply_async(time.sleep,1)
409 ar = v.apply_async(time.sleep,1)
408 # give the message a chance to arrive
410 # give the message a chance to arrive
409 time.sleep(0.2)
411 time.sleep(0.2)
410 ahr = self.client.resubmit(ar.msg_ids)
412 ahr = self.client.resubmit(ar.msg_ids)
411 ar.get(2)
413 ar.get(2)
412 ahr.get(2)
414 ahr.get(2)
413
415
414 def test_resubmit_badkey(self):
416 def test_resubmit_badkey(self):
415 """ensure KeyError on resubmit of nonexistant task"""
417 """ensure KeyError on resubmit of nonexistant task"""
416 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
418 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
417
419
418 def test_purge_hub_results(self):
420 def test_purge_hub_results(self):
419 # ensure there are some tasks
421 # ensure there are some tasks
420 for i in range(5):
422 for i in range(5):
421 self.client[:].apply_sync(lambda : 1)
423 self.client[:].apply_sync(lambda : 1)
422 # Wait for the Hub to realise the result is done:
424 # Wait for the Hub to realise the result is done:
423 # This prevents a race condition, where we
425 # This prevents a race condition, where we
424 # might purge a result the Hub still thinks is pending.
426 # might purge a result the Hub still thinks is pending.
425 self._wait_for_idle()
427 self._wait_for_idle()
426 rc2 = clientmod.Client(profile='iptest')
428 rc2 = clientmod.Client(profile='iptest')
427 hist = self.client.hub_history()
429 hist = self.client.hub_history()
428 ahr = rc2.get_result([hist[-1]])
430 ahr = rc2.get_result([hist[-1]])
429 ahr.wait(10)
431 ahr.wait(10)
430 self.client.purge_hub_results(hist[-1])
432 self.client.purge_hub_results(hist[-1])
431 newhist = self.client.hub_history()
433 newhist = self.client.hub_history()
432 self.assertEqual(len(newhist)+1,len(hist))
434 self.assertEqual(len(newhist)+1,len(hist))
433 rc2.spin()
435 rc2.spin()
434 rc2.close()
436 rc2.close()
435
437
436 def test_purge_local_results(self):
438 def test_purge_local_results(self):
437 # ensure there are some tasks
439 # ensure there are some tasks
438 res = []
440 res = []
439 for i in range(5):
441 for i in range(5):
440 res.append(self.client[:].apply_async(lambda : 1))
442 res.append(self.client[:].apply_async(lambda : 1))
441 self._wait_for_idle()
443 self._wait_for_idle()
442 self.client.wait(10) # wait for the results to come back
444 self.client.wait(10) # wait for the results to come back
443 before = len(self.client.results)
445 before = len(self.client.results)
444 self.assertEqual(len(self.client.metadata),before)
446 self.assertEqual(len(self.client.metadata),before)
445 self.client.purge_local_results(res[-1])
447 self.client.purge_local_results(res[-1])
446 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
448 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
447 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
449 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
448
450
449 def test_purge_local_results_outstanding(self):
451 def test_purge_local_results_outstanding(self):
450 v = self.client[-1]
452 v = self.client[-1]
451 ar = v.apply_async(lambda : 1)
453 ar = v.apply_async(lambda : 1)
452 msg_id = ar.msg_ids[0]
454 msg_id = ar.msg_ids[0]
455 ar.owner = False
453 ar.get()
456 ar.get()
454 self._wait_for_idle()
457 self._wait_for_idle()
455 ar2 = v.apply_async(time.sleep, 1)
458 ar2 = v.apply_async(time.sleep, 1)
456 self.assertIn(msg_id, self.client.results)
459 self.assertIn(msg_id, self.client.results)
457 self.assertIn(msg_id, self.client.metadata)
460 self.assertIn(msg_id, self.client.metadata)
458 self.client.purge_local_results(ar)
461 self.client.purge_local_results(ar)
459 self.assertNotIn(msg_id, self.client.results)
462 self.assertNotIn(msg_id, self.client.results)
460 self.assertNotIn(msg_id, self.client.metadata)
463 self.assertNotIn(msg_id, self.client.metadata)
461 with self.assertRaises(RuntimeError):
464 with self.assertRaises(RuntimeError):
462 self.client.purge_local_results(ar2)
465 self.client.purge_local_results(ar2)
463 ar2.get()
466 ar2.get()
464 self.client.purge_local_results(ar2)
467 self.client.purge_local_results(ar2)
465
468
466 def test_purge_all_local_results_outstanding(self):
469 def test_purge_all_local_results_outstanding(self):
467 v = self.client[-1]
470 v = self.client[-1]
468 ar = v.apply_async(time.sleep, 1)
471 ar = v.apply_async(time.sleep, 1)
469 with self.assertRaises(RuntimeError):
472 with self.assertRaises(RuntimeError):
470 self.client.purge_local_results('all')
473 self.client.purge_local_results('all')
471 ar.get()
474 ar.get()
472 self.client.purge_local_results('all')
475 self.client.purge_local_results('all')
473
476
474 def test_purge_all_hub_results(self):
477 def test_purge_all_hub_results(self):
475 self.client.purge_hub_results('all')
478 self.client.purge_hub_results('all')
476 hist = self.client.hub_history()
479 hist = self.client.hub_history()
477 self.assertEqual(len(hist), 0)
480 self.assertEqual(len(hist), 0)
478
481
479 def test_purge_all_local_results(self):
482 def test_purge_all_local_results(self):
480 self.client.purge_local_results('all')
483 self.client.purge_local_results('all')
481 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
484 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
482 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
485 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
483
486
484 def test_purge_all_results(self):
487 def test_purge_all_results(self):
485 # ensure there are some tasks
488 # ensure there are some tasks
486 for i in range(5):
489 for i in range(5):
487 self.client[:].apply_sync(lambda : 1)
490 self.client[:].apply_sync(lambda : 1)
488 self.client.wait(10)
491 self.client.wait(10)
489 self._wait_for_idle()
492 self._wait_for_idle()
490 self.client.purge_results('all')
493 self.client.purge_results('all')
491 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
494 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
492 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
495 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
493 hist = self.client.hub_history()
496 hist = self.client.hub_history()
494 self.assertEqual(len(hist), 0, msg="hub history not empty")
497 self.assertEqual(len(hist), 0, msg="hub history not empty")
495
498
496 def test_purge_everything(self):
499 def test_purge_everything(self):
497 # ensure there are some tasks
500 # ensure there are some tasks
498 for i in range(5):
501 for i in range(5):
499 self.client[:].apply_sync(lambda : 1)
502 self.client[:].apply_sync(lambda : 1)
500 self.client.wait(10)
503 self.client.wait(10)
501 self._wait_for_idle()
504 self._wait_for_idle()
502 self.client.purge_everything()
505 self.client.purge_everything()
503 # The client results
506 # The client results
504 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
507 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
505 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
508 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
506 # The client "bookkeeping"
509 # The client "bookkeeping"
507 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
510 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
508 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
511 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
509 # the hub results
512 # the hub results
510 hist = self.client.hub_history()
513 hist = self.client.hub_history()
511 self.assertEqual(len(hist), 0, msg="hub history not empty")
514 self.assertEqual(len(hist), 0, msg="hub history not empty")
512
515
513
516
514 def test_spin_thread(self):
517 def test_spin_thread(self):
515 self.client.spin_thread(0.01)
518 self.client.spin_thread(0.01)
516 ar = self.client[-1].apply_async(lambda : 1)
519 ar = self.client[-1].apply_async(lambda : 1)
517 md = self.client.metadata[ar.msg_ids[0]]
520 md = self.client.metadata[ar.msg_ids[0]]
518 # 3s timeout, 100ms poll
521 # 3s timeout, 100ms poll
519 for i in range(30):
522 for i in range(30):
520 time.sleep(0.1)
523 time.sleep(0.1)
521 if md['received'] is not None:
524 if md['received'] is not None:
522 break
525 break
523 self.assertIsInstance(md['received'], datetime)
526 self.assertIsInstance(md['received'], datetime)
524
527
525 def test_stop_spin_thread(self):
528 def test_stop_spin_thread(self):
526 self.client.spin_thread(0.01)
529 self.client.spin_thread(0.01)
527 self.client.stop_spin_thread()
530 self.client.stop_spin_thread()
528 ar = self.client[-1].apply_async(lambda : 1)
531 ar = self.client[-1].apply_async(lambda : 1)
529 md = self.client.metadata[ar.msg_ids[0]]
532 md = self.client.metadata[ar.msg_ids[0]]
530 # 500ms timeout, 100ms poll
533 # 500ms timeout, 100ms poll
531 for i in range(5):
534 for i in range(5):
532 time.sleep(0.1)
535 time.sleep(0.1)
533 self.assertIsNone(md['received'], None)
536 self.assertIsNone(md['received'], None)
534
537
535 def test_activate(self):
538 def test_activate(self):
536 ip = get_ipython()
539 ip = get_ipython()
537 magics = ip.magics_manager.magics
540 magics = ip.magics_manager.magics
538 self.assertTrue('px' in magics['line'])
541 self.assertTrue('px' in magics['line'])
539 self.assertTrue('px' in magics['cell'])
542 self.assertTrue('px' in magics['cell'])
540 v0 = self.client.activate(-1, '0')
543 v0 = self.client.activate(-1, '0')
541 self.assertTrue('px0' in magics['line'])
544 self.assertTrue('px0' in magics['line'])
542 self.assertTrue('px0' in magics['cell'])
545 self.assertTrue('px0' in magics['cell'])
543 self.assertEqual(v0.targets, self.client.ids[-1])
546 self.assertEqual(v0.targets, self.client.ids[-1])
544 v0 = self.client.activate('all', 'all')
547 v0 = self.client.activate('all', 'all')
545 self.assertTrue('pxall' in magics['line'])
548 self.assertTrue('pxall' in magics['line'])
546 self.assertTrue('pxall' in magics['cell'])
549 self.assertTrue('pxall' in magics['cell'])
547 self.assertEqual(v0.targets, 'all')
550 self.assertEqual(v0.targets, 'all')
@@ -1,842 +1,843 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects"""
2 """test View objects"""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import base64
7 import base64
8 import sys
8 import sys
9 import platform
9 import platform
10 import time
10 import time
11 from collections import namedtuple
11 from collections import namedtuple
12 from tempfile import NamedTemporaryFile
12 from tempfile import NamedTemporaryFile
13
13
14 import zmq
14 import zmq
15 from nose.plugins.attrib import attr
15 from nose.plugins.attrib import attr
16
16
17 from IPython.testing import decorators as dec
17 from IPython.testing import decorators as dec
18 from IPython.utils.io import capture_output
18 from IPython.utils.io import capture_output
19 from IPython.utils.py3compat import unicode_type
19 from IPython.utils.py3compat import unicode_type
20
20
21 from IPython import parallel as pmod
21 from IPython import parallel as pmod
22 from IPython.parallel import error
22 from IPython.parallel import error
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 from IPython.parallel.util import interactive
24 from IPython.parallel.util import interactive
25
25
26 from IPython.parallel.tests import add_engines
26 from IPython.parallel.tests import add_engines
27
27
28 from .clienttest import ClusterTestCase, crash, wait, skip_without
28 from .clienttest import ClusterTestCase, crash, wait, skip_without
29
29
30 def setup():
30 def setup():
31 add_engines(3, total=True)
31 add_engines(3, total=True)
32
32
33 point = namedtuple("point", "x y")
33 point = namedtuple("point", "x y")
34
34
35 class TestView(ClusterTestCase):
35 class TestView(ClusterTestCase):
36
36
37 def setUp(self):
37 def setUp(self):
38 # On Win XP, wait for resource cleanup, else parallel test group fails
38 # On Win XP, wait for resource cleanup, else parallel test group fails
39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
41 time.sleep(2)
41 time.sleep(2)
42 super(TestView, self).setUp()
42 super(TestView, self).setUp()
43
43
44 @attr('crash')
44 @attr('crash')
45 def test_z_crash_mux(self):
45 def test_z_crash_mux(self):
46 """test graceful handling of engine death (direct)"""
46 """test graceful handling of engine death (direct)"""
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEqual(d, data)
69 self.assertEqual(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEqual(d, nengines*[data])
72 self.assertEqual(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEqual(r, nengines*[data])
79 self.assertEqual(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEqual(r, nengines*[[10,20]])
82 self.assertEqual(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEqual(r, testf(10))
100 self.assertEqual(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEqual(v['b'], 5)
123 self.assertEqual(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEqual(v['b'], 10)
133 self.assertEqual(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids[0])
145 ahr = v2.get_result(ar.msg_ids[0], owner=False)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertIsInstance(ahr, AsyncHubResult)
147 self.assertEqual(ahr.get(), ar.get())
147 self.assertEqual(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids[0])
148 ar2 = v2.get_result(ar.msg_ids[0])
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertNotIsInstance(ar2, AsyncHubResult)
150 self.assertEqual(ahr.get(), ar2.get())
150 c.spin()
151 c.spin()
151 c.close()
152 c.close()
152
153
153 def test_run_newline(self):
154 def test_run_newline(self):
154 """test that run appends newline to files"""
155 """test that run appends newline to files"""
155 with NamedTemporaryFile('w', delete=False) as f:
156 with NamedTemporaryFile('w', delete=False) as f:
156 f.write("""def g():
157 f.write("""def g():
157 return 5
158 return 5
158 """)
159 """)
159 v = self.client[-1]
160 v = self.client[-1]
160 v.run(f.name, block=True)
161 v.run(f.name, block=True)
161 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162
163
163 def test_apply_tracked(self):
164 def test_apply_tracked(self):
164 """test tracking for apply"""
165 """test tracking for apply"""
165 # self.add_engines(1)
166 # self.add_engines(1)
166 t = self.client.ids[-1]
167 t = self.client.ids[-1]
167 v = self.client[t]
168 v = self.client[t]
168 v.block=False
169 v.block=False
169 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
170 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
171 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
172 ar = echo(1, track=False)
173 ar = echo(1, track=False)
173 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
175 ar = echo(track=True)
176 ar = echo(track=True)
176 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertEqual(ar.sent, ar._tracker.done)
178 self.assertEqual(ar.sent, ar._tracker.done)
178 ar._tracker.wait()
179 ar._tracker.wait()
179 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
180
181
181 def test_push_tracked(self):
182 def test_push_tracked(self):
182 t = self.client.ids[-1]
183 t = self.client.ids[-1]
183 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
184 v = self.client[t]
185 v = self.client[t]
185 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
188
189
189 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
190 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 ar._tracker.wait()
192 ar._tracker.wait()
192 self.assertEqual(ar.sent, ar._tracker.done)
193 self.assertEqual(ar.sent, ar._tracker.done)
193 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
194 ar.get()
195 ar.get()
195
196
196 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
197 t = self.client.ids
198 t = self.client.ids
198 x='x'*1024*1024
199 x='x'*1024*1024
199 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
202
203
203 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertEqual(ar.sent, ar._tracker.done)
206 self.assertEqual(ar.sent, ar._tracker.done)
206 ar._tracker.wait()
207 ar._tracker.wait()
207 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
208 ar.get()
209 ar.get()
209
210
210 def test_remote_reference(self):
211 def test_remote_reference(self):
211 v = self.client[-1]
212 v = self.client[-1]
212 v['a'] = 123
213 v['a'] = 123
213 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
214 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
215 self.assertEqual(b, 123)
216 self.assertEqual(b, 123)
216
217
217
218
218 def test_scatter_gather(self):
219 def test_scatter_gather(self):
219 view = self.client[:]
220 view = self.client[:]
220 seq1 = list(range(16))
221 seq1 = list(range(16))
221 view.scatter('a', seq1)
222 view.scatter('a', seq1)
222 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
223 self.assertEqual(seq2, seq1)
224 self.assertEqual(seq2, seq1)
224 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225
226
226 @skip_without('numpy')
227 @skip_without('numpy')
227 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
228 import numpy
229 import numpy
229 from numpy.testing.utils import assert_array_equal
230 from numpy.testing.utils import assert_array_equal
230 view = self.client[:]
231 view = self.client[:]
231 a = numpy.arange(64)
232 a = numpy.arange(64)
232 view.scatter('a', a, block=True)
233 view.scatter('a', a, block=True)
233 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
234 assert_array_equal(b, a)
235 assert_array_equal(b, a)
235
236
236 def test_scatter_gather_lazy(self):
237 def test_scatter_gather_lazy(self):
237 """scatter/gather with targets='all'"""
238 """scatter/gather with targets='all'"""
238 view = self.client.direct_view(targets='all')
239 view = self.client.direct_view(targets='all')
239 x = list(range(64))
240 x = list(range(64))
240 view.scatter('x', x)
241 view.scatter('x', x)
241 gathered = view.gather('x', block=True)
242 gathered = view.gather('x', block=True)
242 self.assertEqual(gathered, x)
243 self.assertEqual(gathered, x)
243
244
244
245
245 @dec.known_failure_py3
246 @dec.known_failure_py3
246 @skip_without('numpy')
247 @skip_without('numpy')
247 def test_push_numpy_nocopy(self):
248 def test_push_numpy_nocopy(self):
248 import numpy
249 import numpy
249 view = self.client[:]
250 view = self.client[:]
250 a = numpy.arange(64)
251 a = numpy.arange(64)
251 view['A'] = a
252 view['A'] = a
252 @interactive
253 @interactive
253 def check_writeable(x):
254 def check_writeable(x):
254 return x.flags.writeable
255 return x.flags.writeable
255
256
256 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
257 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
257 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
258 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
258
259
259 view.push(dict(B=a))
260 view.push(dict(B=a))
260 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
261 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
261 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
262 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
262
263
263 @skip_without('numpy')
264 @skip_without('numpy')
264 def test_apply_numpy(self):
265 def test_apply_numpy(self):
265 """view.apply(f, ndarray)"""
266 """view.apply(f, ndarray)"""
266 import numpy
267 import numpy
267 from numpy.testing.utils import assert_array_equal
268 from numpy.testing.utils import assert_array_equal
268
269
269 A = numpy.random.random((100,100))
270 A = numpy.random.random((100,100))
270 view = self.client[-1]
271 view = self.client[-1]
271 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
272 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
272 B = A.astype(dt)
273 B = A.astype(dt)
273 C = view.apply_sync(lambda x:x, B)
274 C = view.apply_sync(lambda x:x, B)
274 assert_array_equal(B,C)
275 assert_array_equal(B,C)
275
276
276 @skip_without('numpy')
277 @skip_without('numpy')
277 def test_apply_numpy_object_dtype(self):
278 def test_apply_numpy_object_dtype(self):
278 """view.apply(f, ndarray) with dtype=object"""
279 """view.apply(f, ndarray) with dtype=object"""
279 import numpy
280 import numpy
280 from numpy.testing.utils import assert_array_equal
281 from numpy.testing.utils import assert_array_equal
281 view = self.client[-1]
282 view = self.client[-1]
282
283
283 A = numpy.array([dict(a=5)])
284 A = numpy.array([dict(a=5)])
284 B = view.apply_sync(lambda x:x, A)
285 B = view.apply_sync(lambda x:x, A)
285 assert_array_equal(A,B)
286 assert_array_equal(A,B)
286
287
287 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
288 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
288 B = view.apply_sync(lambda x:x, A)
289 B = view.apply_sync(lambda x:x, A)
289 assert_array_equal(A,B)
290 assert_array_equal(A,B)
290
291
291 @skip_without('numpy')
292 @skip_without('numpy')
292 def test_push_pull_recarray(self):
293 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
294 """push/pull recarrays"""
294 import numpy
295 import numpy
295 from numpy.testing.utils import assert_array_equal
296 from numpy.testing.utils import assert_array_equal
296
297
297 view = self.client[-1]
298 view = self.client[-1]
298
299
299 R = numpy.array([
300 R = numpy.array([
300 (1, 'hi', 0.),
301 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
302 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
303 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
304 ], [('n', int), ('s', '|S10'), ('f', float)])
304
305
305 view['RR'] = R
306 view['RR'] = R
306 R2 = view['RR']
307 R2 = view['RR']
307
308
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
313 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
314 assert_array_equal(R2, R)
314
315
315 @skip_without('pandas')
316 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
317 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
318 """push/pull pandas.TimeSeries"""
318 import pandas
319 import pandas
319
320
320 ts = pandas.TimeSeries(list(range(10)))
321 ts = pandas.TimeSeries(list(range(10)))
321
322
322 view = self.client[-1]
323 view = self.client[-1]
323
324
324 view.push(dict(ts=ts), block=True)
325 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
326 rts = view['ts']
326
327
327 self.assertEqual(type(rts), type(ts))
328 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
329 self.assertTrue((ts == rts).all())
329
330
330 def test_map(self):
331 def test_map(self):
331 view = self.client[:]
332 view = self.client[:]
332 def f(x):
333 def f(x):
333 return x**2
334 return x**2
334 data = list(range(16))
335 data = list(range(16))
335 r = view.map_sync(f, data)
336 r = view.map_sync(f, data)
336 self.assertEqual(r, list(map(f, data)))
337 self.assertEqual(r, list(map(f, data)))
337
338
338 def test_map_empty_sequence(self):
339 def test_map_empty_sequence(self):
339 view = self.client[:]
340 view = self.client[:]
340 r = view.map_sync(lambda x: x, [])
341 r = view.map_sync(lambda x: x, [])
341 self.assertEqual(r, [])
342 self.assertEqual(r, [])
342
343
343 def test_map_iterable(self):
344 def test_map_iterable(self):
344 """test map on iterables (direct)"""
345 """test map on iterables (direct)"""
345 view = self.client[:]
346 view = self.client[:]
346 # 101 is prime, so it won't be evenly distributed
347 # 101 is prime, so it won't be evenly distributed
347 arr = range(101)
348 arr = range(101)
348 # ensure it will be an iterator, even in Python 3
349 # ensure it will be an iterator, even in Python 3
349 it = iter(arr)
350 it = iter(arr)
350 r = view.map_sync(lambda x: x, it)
351 r = view.map_sync(lambda x: x, it)
351 self.assertEqual(r, list(arr))
352 self.assertEqual(r, list(arr))
352
353
353 @skip_without('numpy')
354 @skip_without('numpy')
354 def test_map_numpy(self):
355 def test_map_numpy(self):
355 """test map on numpy arrays (direct)"""
356 """test map on numpy arrays (direct)"""
356 import numpy
357 import numpy
357 from numpy.testing.utils import assert_array_equal
358 from numpy.testing.utils import assert_array_equal
358
359
359 view = self.client[:]
360 view = self.client[:]
360 # 101 is prime, so it won't be evenly distributed
361 # 101 is prime, so it won't be evenly distributed
361 arr = numpy.arange(101)
362 arr = numpy.arange(101)
362 r = view.map_sync(lambda x: x, arr)
363 r = view.map_sync(lambda x: x, arr)
363 assert_array_equal(r, arr)
364 assert_array_equal(r, arr)
364
365
365 def test_scatter_gather_nonblocking(self):
366 def test_scatter_gather_nonblocking(self):
366 data = list(range(16))
367 data = list(range(16))
367 view = self.client[:]
368 view = self.client[:]
368 view.scatter('a', data, block=False)
369 view.scatter('a', data, block=False)
369 ar = view.gather('a', block=False)
370 ar = view.gather('a', block=False)
370 self.assertEqual(ar.get(), data)
371 self.assertEqual(ar.get(), data)
371
372
372 @skip_without('numpy')
373 @skip_without('numpy')
373 def test_scatter_gather_numpy_nonblocking(self):
374 def test_scatter_gather_numpy_nonblocking(self):
374 import numpy
375 import numpy
375 from numpy.testing.utils import assert_array_equal
376 from numpy.testing.utils import assert_array_equal
376 a = numpy.arange(64)
377 a = numpy.arange(64)
377 view = self.client[:]
378 view = self.client[:]
378 ar = view.scatter('a', a, block=False)
379 ar = view.scatter('a', a, block=False)
379 self.assertTrue(isinstance(ar, AsyncResult))
380 self.assertTrue(isinstance(ar, AsyncResult))
380 amr = view.gather('a', block=False)
381 amr = view.gather('a', block=False)
381 self.assertTrue(isinstance(amr, AsyncMapResult))
382 self.assertTrue(isinstance(amr, AsyncMapResult))
382 assert_array_equal(amr.get(), a)
383 assert_array_equal(amr.get(), a)
383
384
384 def test_execute(self):
385 def test_execute(self):
385 view = self.client[:]
386 view = self.client[:]
386 # self.client.debug=True
387 # self.client.debug=True
387 execute = view.execute
388 execute = view.execute
388 ar = execute('c=30', block=False)
389 ar = execute('c=30', block=False)
389 self.assertTrue(isinstance(ar, AsyncResult))
390 self.assertTrue(isinstance(ar, AsyncResult))
390 ar = execute('d=[0,1,2]', block=False)
391 ar = execute('d=[0,1,2]', block=False)
391 self.client.wait(ar, 1)
392 self.client.wait(ar, 1)
392 self.assertEqual(len(ar.get()), len(self.client))
393 self.assertEqual(len(ar.get()), len(self.client))
393 for c in view['c']:
394 for c in view['c']:
394 self.assertEqual(c, 30)
395 self.assertEqual(c, 30)
395
396
396 def test_abort(self):
397 def test_abort(self):
397 view = self.client[-1]
398 view = self.client[-1]
398 ar = view.execute('import time; time.sleep(1)', block=False)
399 ar = view.execute('import time; time.sleep(1)', block=False)
399 ar2 = view.apply_async(lambda : 2)
400 ar2 = view.apply_async(lambda : 2)
400 ar3 = view.apply_async(lambda : 3)
401 ar3 = view.apply_async(lambda : 3)
401 view.abort(ar2)
402 view.abort(ar2)
402 view.abort(ar3.msg_ids)
403 view.abort(ar3.msg_ids)
403 self.assertRaises(error.TaskAborted, ar2.get)
404 self.assertRaises(error.TaskAborted, ar2.get)
404 self.assertRaises(error.TaskAborted, ar3.get)
405 self.assertRaises(error.TaskAborted, ar3.get)
405
406
406 def test_abort_all(self):
407 def test_abort_all(self):
407 """view.abort() aborts all outstanding tasks"""
408 """view.abort() aborts all outstanding tasks"""
408 view = self.client[-1]
409 view = self.client[-1]
409 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
410 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
410 view.abort()
411 view.abort()
411 view.wait(timeout=5)
412 view.wait(timeout=5)
412 for ar in ars[5:]:
413 for ar in ars[5:]:
413 self.assertRaises(error.TaskAborted, ar.get)
414 self.assertRaises(error.TaskAborted, ar.get)
414
415
415 def test_temp_flags(self):
416 def test_temp_flags(self):
416 view = self.client[-1]
417 view = self.client[-1]
417 view.block=True
418 view.block=True
418 with view.temp_flags(block=False):
419 with view.temp_flags(block=False):
419 self.assertFalse(view.block)
420 self.assertFalse(view.block)
420 self.assertTrue(view.block)
421 self.assertTrue(view.block)
421
422
422 @dec.known_failure_py3
423 @dec.known_failure_py3
423 def test_importer(self):
424 def test_importer(self):
424 view = self.client[-1]
425 view = self.client[-1]
425 view.clear(block=True)
426 view.clear(block=True)
426 with view.importer:
427 with view.importer:
427 import re
428 import re
428
429
429 @interactive
430 @interactive
430 def findall(pat, s):
431 def findall(pat, s):
431 # this globals() step isn't necessary in real code
432 # this globals() step isn't necessary in real code
432 # only to prevent a closure in the test
433 # only to prevent a closure in the test
433 re = globals()['re']
434 re = globals()['re']
434 return re.findall(pat, s)
435 return re.findall(pat, s)
435
436
436 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
437 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
437
438
438 def test_unicode_execute(self):
439 def test_unicode_execute(self):
439 """test executing unicode strings"""
440 """test executing unicode strings"""
440 v = self.client[-1]
441 v = self.client[-1]
441 v.block=True
442 v.block=True
442 if sys.version_info[0] >= 3:
443 if sys.version_info[0] >= 3:
443 code="a='é'"
444 code="a='é'"
444 else:
445 else:
445 code=u"a=u'é'"
446 code=u"a=u'é'"
446 v.execute(code)
447 v.execute(code)
447 self.assertEqual(v['a'], u'é')
448 self.assertEqual(v['a'], u'é')
448
449
449 def test_unicode_apply_result(self):
450 def test_unicode_apply_result(self):
450 """test unicode apply results"""
451 """test unicode apply results"""
451 v = self.client[-1]
452 v = self.client[-1]
452 r = v.apply_sync(lambda : u'é')
453 r = v.apply_sync(lambda : u'é')
453 self.assertEqual(r, u'é')
454 self.assertEqual(r, u'é')
454
455
455 def test_unicode_apply_arg(self):
456 def test_unicode_apply_arg(self):
456 """test passing unicode arguments to apply"""
457 """test passing unicode arguments to apply"""
457 v = self.client[-1]
458 v = self.client[-1]
458
459
459 @interactive
460 @interactive
460 def check_unicode(a, check):
461 def check_unicode(a, check):
461 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
462 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
462 assert isinstance(check, bytes), "%r is not bytes"%check
463 assert isinstance(check, bytes), "%r is not bytes"%check
463 assert a.encode('utf8') == check, "%s != %s"%(a,check)
464 assert a.encode('utf8') == check, "%s != %s"%(a,check)
464
465
465 for s in [ u'é', u'ßø®∫',u'asdf' ]:
466 for s in [ u'é', u'ßø®∫',u'asdf' ]:
466 try:
467 try:
467 v.apply_sync(check_unicode, s, s.encode('utf8'))
468 v.apply_sync(check_unicode, s, s.encode('utf8'))
468 except error.RemoteError as e:
469 except error.RemoteError as e:
469 if e.ename == 'AssertionError':
470 if e.ename == 'AssertionError':
470 self.fail(e.evalue)
471 self.fail(e.evalue)
471 else:
472 else:
472 raise e
473 raise e
473
474
474 def test_map_reference(self):
475 def test_map_reference(self):
475 """view.map(<Reference>, *seqs) should work"""
476 """view.map(<Reference>, *seqs) should work"""
476 v = self.client[:]
477 v = self.client[:]
477 v.scatter('n', self.client.ids, flatten=True)
478 v.scatter('n', self.client.ids, flatten=True)
478 v.execute("f = lambda x,y: x*y")
479 v.execute("f = lambda x,y: x*y")
479 rf = pmod.Reference('f')
480 rf = pmod.Reference('f')
480 nlist = list(range(10))
481 nlist = list(range(10))
481 mlist = nlist[::-1]
482 mlist = nlist[::-1]
482 expected = [ m*n for m,n in zip(mlist, nlist) ]
483 expected = [ m*n for m,n in zip(mlist, nlist) ]
483 result = v.map_sync(rf, mlist, nlist)
484 result = v.map_sync(rf, mlist, nlist)
484 self.assertEqual(result, expected)
485 self.assertEqual(result, expected)
485
486
486 def test_apply_reference(self):
487 def test_apply_reference(self):
487 """view.apply(<Reference>, *args) should work"""
488 """view.apply(<Reference>, *args) should work"""
488 v = self.client[:]
489 v = self.client[:]
489 v.scatter('n', self.client.ids, flatten=True)
490 v.scatter('n', self.client.ids, flatten=True)
490 v.execute("f = lambda x: n*x")
491 v.execute("f = lambda x: n*x")
491 rf = pmod.Reference('f')
492 rf = pmod.Reference('f')
492 result = v.apply_sync(rf, 5)
493 result = v.apply_sync(rf, 5)
493 expected = [ 5*id for id in self.client.ids ]
494 expected = [ 5*id for id in self.client.ids ]
494 self.assertEqual(result, expected)
495 self.assertEqual(result, expected)
495
496
496 def test_eval_reference(self):
497 def test_eval_reference(self):
497 v = self.client[self.client.ids[0]]
498 v = self.client[self.client.ids[0]]
498 v['g'] = list(range(5))
499 v['g'] = list(range(5))
499 rg = pmod.Reference('g[0]')
500 rg = pmod.Reference('g[0]')
500 echo = lambda x:x
501 echo = lambda x:x
501 self.assertEqual(v.apply_sync(echo, rg), 0)
502 self.assertEqual(v.apply_sync(echo, rg), 0)
502
503
503 def test_reference_nameerror(self):
504 def test_reference_nameerror(self):
504 v = self.client[self.client.ids[0]]
505 v = self.client[self.client.ids[0]]
505 r = pmod.Reference('elvis_has_left')
506 r = pmod.Reference('elvis_has_left')
506 echo = lambda x:x
507 echo = lambda x:x
507 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
508 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
508
509
509 def test_single_engine_map(self):
510 def test_single_engine_map(self):
510 e0 = self.client[self.client.ids[0]]
511 e0 = self.client[self.client.ids[0]]
511 r = list(range(5))
512 r = list(range(5))
512 check = [ -1*i for i in r ]
513 check = [ -1*i for i in r ]
513 result = e0.map_sync(lambda x: -1*x, r)
514 result = e0.map_sync(lambda x: -1*x, r)
514 self.assertEqual(result, check)
515 self.assertEqual(result, check)
515
516
516 def test_len(self):
517 def test_len(self):
517 """len(view) makes sense"""
518 """len(view) makes sense"""
518 e0 = self.client[self.client.ids[0]]
519 e0 = self.client[self.client.ids[0]]
519 self.assertEqual(len(e0), 1)
520 self.assertEqual(len(e0), 1)
520 v = self.client[:]
521 v = self.client[:]
521 self.assertEqual(len(v), len(self.client.ids))
522 self.assertEqual(len(v), len(self.client.ids))
522 v = self.client.direct_view('all')
523 v = self.client.direct_view('all')
523 self.assertEqual(len(v), len(self.client.ids))
524 self.assertEqual(len(v), len(self.client.ids))
524 v = self.client[:2]
525 v = self.client[:2]
525 self.assertEqual(len(v), 2)
526 self.assertEqual(len(v), 2)
526 v = self.client[:1]
527 v = self.client[:1]
527 self.assertEqual(len(v), 1)
528 self.assertEqual(len(v), 1)
528 v = self.client.load_balanced_view()
529 v = self.client.load_balanced_view()
529 self.assertEqual(len(v), len(self.client.ids))
530 self.assertEqual(len(v), len(self.client.ids))
530
531
531
532
532 # begin execute tests
533 # begin execute tests
533
534
534 def test_execute_reply(self):
535 def test_execute_reply(self):
535 e0 = self.client[self.client.ids[0]]
536 e0 = self.client[self.client.ids[0]]
536 e0.block = True
537 e0.block = True
537 ar = e0.execute("5", silent=False)
538 ar = e0.execute("5", silent=False)
538 er = ar.get()
539 er = ar.get()
539 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
540 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
540 self.assertEqual(er.execute_result['data']['text/plain'], '5')
541 self.assertEqual(er.execute_result['data']['text/plain'], '5')
541
542
542 def test_execute_reply_rich(self):
543 def test_execute_reply_rich(self):
543 e0 = self.client[self.client.ids[0]]
544 e0 = self.client[self.client.ids[0]]
544 e0.block = True
545 e0.block = True
545 e0.execute("from IPython.display import Image, HTML")
546 e0.execute("from IPython.display import Image, HTML")
546 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
547 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
547 er = ar.get()
548 er = ar.get()
548 b64data = base64.encodestring(b'garbage').decode('ascii')
549 b64data = base64.encodestring(b'garbage').decode('ascii')
549 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
550 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
550 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
551 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
551 er = ar.get()
552 er = ar.get()
552 self.assertEqual(er._repr_html_(), "<b>bold</b>")
553 self.assertEqual(er._repr_html_(), "<b>bold</b>")
553
554
554 def test_execute_reply_stdout(self):
555 def test_execute_reply_stdout(self):
555 e0 = self.client[self.client.ids[0]]
556 e0 = self.client[self.client.ids[0]]
556 e0.block = True
557 e0.block = True
557 ar = e0.execute("print (5)", silent=False)
558 ar = e0.execute("print (5)", silent=False)
558 er = ar.get()
559 er = ar.get()
559 self.assertEqual(er.stdout.strip(), '5')
560 self.assertEqual(er.stdout.strip(), '5')
560
561
561 def test_execute_result(self):
562 def test_execute_result(self):
562 """execute triggers execute_result with silent=False"""
563 """execute triggers execute_result with silent=False"""
563 view = self.client[:]
564 view = self.client[:]
564 ar = view.execute("5", silent=False, block=True)
565 ar = view.execute("5", silent=False, block=True)
565
566
566 expected = [{'text/plain' : '5'}] * len(view)
567 expected = [{'text/plain' : '5'}] * len(view)
567 mimes = [ out['data'] for out in ar.execute_result ]
568 mimes = [ out['data'] for out in ar.execute_result ]
568 self.assertEqual(mimes, expected)
569 self.assertEqual(mimes, expected)
569
570
570 def test_execute_silent(self):
571 def test_execute_silent(self):
571 """execute does not trigger execute_result with silent=True"""
572 """execute does not trigger execute_result with silent=True"""
572 view = self.client[:]
573 view = self.client[:]
573 ar = view.execute("5", block=True)
574 ar = view.execute("5", block=True)
574 expected = [None] * len(view)
575 expected = [None] * len(view)
575 self.assertEqual(ar.execute_result, expected)
576 self.assertEqual(ar.execute_result, expected)
576
577
577 def test_execute_magic(self):
578 def test_execute_magic(self):
578 """execute accepts IPython commands"""
579 """execute accepts IPython commands"""
579 view = self.client[:]
580 view = self.client[:]
580 view.execute("a = 5")
581 view.execute("a = 5")
581 ar = view.execute("%whos", block=True)
582 ar = view.execute("%whos", block=True)
582 # this will raise, if that failed
583 # this will raise, if that failed
583 ar.get(5)
584 ar.get(5)
584 for stdout in ar.stdout:
585 for stdout in ar.stdout:
585 lines = stdout.splitlines()
586 lines = stdout.splitlines()
586 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
587 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
587 found = False
588 found = False
588 for line in lines[2:]:
589 for line in lines[2:]:
589 split = line.split()
590 split = line.split()
590 if split == ['a', 'int', '5']:
591 if split == ['a', 'int', '5']:
591 found = True
592 found = True
592 break
593 break
593 self.assertTrue(found, "whos output wrong: %s" % stdout)
594 self.assertTrue(found, "whos output wrong: %s" % stdout)
594
595
595 def test_execute_displaypub(self):
596 def test_execute_displaypub(self):
596 """execute tracks display_pub output"""
597 """execute tracks display_pub output"""
597 view = self.client[:]
598 view = self.client[:]
598 view.execute("from IPython.core.display import *")
599 view.execute("from IPython.core.display import *")
599 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
600 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
600
601
601 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
602 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
602 for outputs in ar.outputs:
603 for outputs in ar.outputs:
603 mimes = [ out['data'] for out in outputs ]
604 mimes = [ out['data'] for out in outputs ]
604 self.assertEqual(mimes, expected)
605 self.assertEqual(mimes, expected)
605
606
606 def test_apply_displaypub(self):
607 def test_apply_displaypub(self):
607 """apply tracks display_pub output"""
608 """apply tracks display_pub output"""
608 view = self.client[:]
609 view = self.client[:]
609 view.execute("from IPython.core.display import *")
610 view.execute("from IPython.core.display import *")
610
611
611 @interactive
612 @interactive
612 def publish():
613 def publish():
613 [ display(i) for i in range(5) ]
614 [ display(i) for i in range(5) ]
614
615
615 ar = view.apply_async(publish)
616 ar = view.apply_async(publish)
616 ar.get(5)
617 ar.get(5)
617 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
618 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
618 for outputs in ar.outputs:
619 for outputs in ar.outputs:
619 mimes = [ out['data'] for out in outputs ]
620 mimes = [ out['data'] for out in outputs ]
620 self.assertEqual(mimes, expected)
621 self.assertEqual(mimes, expected)
621
622
622 def test_execute_raises(self):
623 def test_execute_raises(self):
623 """exceptions in execute requests raise appropriately"""
624 """exceptions in execute requests raise appropriately"""
624 view = self.client[-1]
625 view = self.client[-1]
625 ar = view.execute("1/0")
626 ar = view.execute("1/0")
626 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
627 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
627
628
628 def test_remoteerror_render_exception(self):
629 def test_remoteerror_render_exception(self):
629 """RemoteErrors get nice tracebacks"""
630 """RemoteErrors get nice tracebacks"""
630 view = self.client[-1]
631 view = self.client[-1]
631 ar = view.execute("1/0")
632 ar = view.execute("1/0")
632 ip = get_ipython()
633 ip = get_ipython()
633 ip.user_ns['ar'] = ar
634 ip.user_ns['ar'] = ar
634 with capture_output() as io:
635 with capture_output() as io:
635 ip.run_cell("ar.get(2)")
636 ip.run_cell("ar.get(2)")
636
637
637 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
638 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
638
639
639 def test_compositeerror_render_exception(self):
640 def test_compositeerror_render_exception(self):
640 """CompositeErrors get nice tracebacks"""
641 """CompositeErrors get nice tracebacks"""
641 view = self.client[:]
642 view = self.client[:]
642 ar = view.execute("1/0")
643 ar = view.execute("1/0")
643 ip = get_ipython()
644 ip = get_ipython()
644 ip.user_ns['ar'] = ar
645 ip.user_ns['ar'] = ar
645
646
646 with capture_output() as io:
647 with capture_output() as io:
647 ip.run_cell("ar.get(2)")
648 ip.run_cell("ar.get(2)")
648
649
649 count = min(error.CompositeError.tb_limit, len(view))
650 count = min(error.CompositeError.tb_limit, len(view))
650
651
651 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
652 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
652 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654
655
655 def test_compositeerror_truncate(self):
656 def test_compositeerror_truncate(self):
656 """Truncate CompositeErrors with many exceptions"""
657 """Truncate CompositeErrors with many exceptions"""
657 view = self.client[:]
658 view = self.client[:]
658 msg_ids = []
659 msg_ids = []
659 for i in range(10):
660 for i in range(10):
660 ar = view.execute("1/0")
661 ar = view.execute("1/0")
661 msg_ids.extend(ar.msg_ids)
662 msg_ids.extend(ar.msg_ids)
662
663
663 ar = self.client.get_result(msg_ids)
664 ar = self.client.get_result(msg_ids)
664 try:
665 try:
665 ar.get()
666 ar.get()
666 except error.CompositeError as _e:
667 except error.CompositeError as _e:
667 e = _e
668 e = _e
668 else:
669 else:
669 self.fail("Should have raised CompositeError")
670 self.fail("Should have raised CompositeError")
670
671
671 lines = e.render_traceback()
672 lines = e.render_traceback()
672 with capture_output() as io:
673 with capture_output() as io:
673 e.print_traceback()
674 e.print_traceback()
674
675
675 self.assertTrue("more exceptions" in lines[-1])
676 self.assertTrue("more exceptions" in lines[-1])
676 count = e.tb_limit
677 count = e.tb_limit
677
678
678 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
679 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
679 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
680 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
680 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
681 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
681
682
682 @dec.skipif_not_matplotlib
683 @dec.skipif_not_matplotlib
683 def test_magic_pylab(self):
684 def test_magic_pylab(self):
684 """%pylab works on engines"""
685 """%pylab works on engines"""
685 view = self.client[-1]
686 view = self.client[-1]
686 ar = view.execute("%pylab inline")
687 ar = view.execute("%pylab inline")
687 # at least check if this raised:
688 # at least check if this raised:
688 reply = ar.get(5)
689 reply = ar.get(5)
689 # include imports, in case user config
690 # include imports, in case user config
690 ar = view.execute("plot(rand(100))", silent=False)
691 ar = view.execute("plot(rand(100))", silent=False)
691 reply = ar.get(5)
692 reply = ar.get(5)
692 self.assertEqual(len(reply.outputs), 1)
693 self.assertEqual(len(reply.outputs), 1)
693 output = reply.outputs[0]
694 output = reply.outputs[0]
694 self.assertTrue("data" in output)
695 self.assertTrue("data" in output)
695 data = output['data']
696 data = output['data']
696 self.assertTrue("image/png" in data)
697 self.assertTrue("image/png" in data)
697
698
698 def test_func_default_func(self):
699 def test_func_default_func(self):
699 """interactively defined function as apply func default"""
700 """interactively defined function as apply func default"""
700 def foo():
701 def foo():
701 return 'foo'
702 return 'foo'
702
703
703 def bar(f=foo):
704 def bar(f=foo):
704 return f()
705 return f()
705
706
706 view = self.client[-1]
707 view = self.client[-1]
707 ar = view.apply_async(bar)
708 ar = view.apply_async(bar)
708 r = ar.get(10)
709 r = ar.get(10)
709 self.assertEqual(r, 'foo')
710 self.assertEqual(r, 'foo')
710 def test_data_pub_single(self):
711 def test_data_pub_single(self):
711 view = self.client[-1]
712 view = self.client[-1]
712 ar = view.execute('\n'.join([
713 ar = view.execute('\n'.join([
713 'from IPython.kernel.zmq.datapub import publish_data',
714 'from IPython.kernel.zmq.datapub import publish_data',
714 'for i in range(5):',
715 'for i in range(5):',
715 ' publish_data(dict(i=i))'
716 ' publish_data(dict(i=i))'
716 ]), block=False)
717 ]), block=False)
717 self.assertTrue(isinstance(ar.data, dict))
718 self.assertTrue(isinstance(ar.data, dict))
718 ar.get(5)
719 ar.get(5)
719 self.assertEqual(ar.data, dict(i=4))
720 self.assertEqual(ar.data, dict(i=4))
720
721
721 def test_data_pub(self):
722 def test_data_pub(self):
722 view = self.client[:]
723 view = self.client[:]
723 ar = view.execute('\n'.join([
724 ar = view.execute('\n'.join([
724 'from IPython.kernel.zmq.datapub import publish_data',
725 'from IPython.kernel.zmq.datapub import publish_data',
725 'for i in range(5):',
726 'for i in range(5):',
726 ' publish_data(dict(i=i))'
727 ' publish_data(dict(i=i))'
727 ]), block=False)
728 ]), block=False)
728 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
729 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
729 ar.get(5)
730 ar.get(5)
730 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
731 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
731
732
732 def test_can_list_arg(self):
733 def test_can_list_arg(self):
733 """args in lists are canned"""
734 """args in lists are canned"""
734 view = self.client[-1]
735 view = self.client[-1]
735 view['a'] = 128
736 view['a'] = 128
736 rA = pmod.Reference('a')
737 rA = pmod.Reference('a')
737 ar = view.apply_async(lambda x: x, [rA])
738 ar = view.apply_async(lambda x: x, [rA])
738 r = ar.get(5)
739 r = ar.get(5)
739 self.assertEqual(r, [128])
740 self.assertEqual(r, [128])
740
741
741 def test_can_dict_arg(self):
742 def test_can_dict_arg(self):
742 """args in dicts are canned"""
743 """args in dicts are canned"""
743 view = self.client[-1]
744 view = self.client[-1]
744 view['a'] = 128
745 view['a'] = 128
745 rA = pmod.Reference('a')
746 rA = pmod.Reference('a')
746 ar = view.apply_async(lambda x: x, dict(foo=rA))
747 ar = view.apply_async(lambda x: x, dict(foo=rA))
747 r = ar.get(5)
748 r = ar.get(5)
748 self.assertEqual(r, dict(foo=128))
749 self.assertEqual(r, dict(foo=128))
749
750
750 def test_can_list_kwarg(self):
751 def test_can_list_kwarg(self):
751 """kwargs in lists are canned"""
752 """kwargs in lists are canned"""
752 view = self.client[-1]
753 view = self.client[-1]
753 view['a'] = 128
754 view['a'] = 128
754 rA = pmod.Reference('a')
755 rA = pmod.Reference('a')
755 ar = view.apply_async(lambda x=5: x, x=[rA])
756 ar = view.apply_async(lambda x=5: x, x=[rA])
756 r = ar.get(5)
757 r = ar.get(5)
757 self.assertEqual(r, [128])
758 self.assertEqual(r, [128])
758
759
759 def test_can_dict_kwarg(self):
760 def test_can_dict_kwarg(self):
760 """kwargs in dicts are canned"""
761 """kwargs in dicts are canned"""
761 view = self.client[-1]
762 view = self.client[-1]
762 view['a'] = 128
763 view['a'] = 128
763 rA = pmod.Reference('a')
764 rA = pmod.Reference('a')
764 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
765 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
765 r = ar.get(5)
766 r = ar.get(5)
766 self.assertEqual(r, dict(foo=128))
767 self.assertEqual(r, dict(foo=128))
767
768
768 def test_map_ref(self):
769 def test_map_ref(self):
769 """view.map works with references"""
770 """view.map works with references"""
770 view = self.client[:]
771 view = self.client[:]
771 ranks = sorted(self.client.ids)
772 ranks = sorted(self.client.ids)
772 view.scatter('rank', ranks, flatten=True)
773 view.scatter('rank', ranks, flatten=True)
773 rrank = pmod.Reference('rank')
774 rrank = pmod.Reference('rank')
774
775
775 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
776 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
776 drank = amr.get(5)
777 drank = amr.get(5)
777 self.assertEqual(drank, [ r*2 for r in ranks ])
778 self.assertEqual(drank, [ r*2 for r in ranks ])
778
779
779 def test_nested_getitem_setitem(self):
780 def test_nested_getitem_setitem(self):
780 """get and set with view['a.b']"""
781 """get and set with view['a.b']"""
781 view = self.client[-1]
782 view = self.client[-1]
782 view.execute('\n'.join([
783 view.execute('\n'.join([
783 'class A(object): pass',
784 'class A(object): pass',
784 'a = A()',
785 'a = A()',
785 'a.b = 128',
786 'a.b = 128',
786 ]), block=True)
787 ]), block=True)
787 ra = pmod.Reference('a')
788 ra = pmod.Reference('a')
788
789
789 r = view.apply_sync(lambda x: x.b, ra)
790 r = view.apply_sync(lambda x: x.b, ra)
790 self.assertEqual(r, 128)
791 self.assertEqual(r, 128)
791 self.assertEqual(view['a.b'], 128)
792 self.assertEqual(view['a.b'], 128)
792
793
793 view['a.b'] = 0
794 view['a.b'] = 0
794
795
795 r = view.apply_sync(lambda x: x.b, ra)
796 r = view.apply_sync(lambda x: x.b, ra)
796 self.assertEqual(r, 0)
797 self.assertEqual(r, 0)
797 self.assertEqual(view['a.b'], 0)
798 self.assertEqual(view['a.b'], 0)
798
799
799 def test_return_namedtuple(self):
800 def test_return_namedtuple(self):
800 def namedtuplify(x, y):
801 def namedtuplify(x, y):
801 from IPython.parallel.tests.test_view import point
802 from IPython.parallel.tests.test_view import point
802 return point(x, y)
803 return point(x, y)
803
804
804 view = self.client[-1]
805 view = self.client[-1]
805 p = view.apply_sync(namedtuplify, 1, 2)
806 p = view.apply_sync(namedtuplify, 1, 2)
806 self.assertEqual(p.x, 1)
807 self.assertEqual(p.x, 1)
807 self.assertEqual(p.y, 2)
808 self.assertEqual(p.y, 2)
808
809
809 def test_apply_namedtuple(self):
810 def test_apply_namedtuple(self):
810 def echoxy(p):
811 def echoxy(p):
811 return p.y, p.x
812 return p.y, p.x
812
813
813 view = self.client[-1]
814 view = self.client[-1]
814 tup = view.apply_sync(echoxy, point(1, 2))
815 tup = view.apply_sync(echoxy, point(1, 2))
815 self.assertEqual(tup, (2,1))
816 self.assertEqual(tup, (2,1))
816
817
817 def test_sync_imports(self):
818 def test_sync_imports(self):
818 view = self.client[-1]
819 view = self.client[-1]
819 with capture_output() as io:
820 with capture_output() as io:
820 with view.sync_imports():
821 with view.sync_imports():
821 import IPython
822 import IPython
822 self.assertIn("IPython", io.stdout)
823 self.assertIn("IPython", io.stdout)
823
824
824 @interactive
825 @interactive
825 def find_ipython():
826 def find_ipython():
826 return 'IPython' in globals()
827 return 'IPython' in globals()
827
828
828 assert view.apply_sync(find_ipython)
829 assert view.apply_sync(find_ipython)
829
830
830 def test_sync_imports_quiet(self):
831 def test_sync_imports_quiet(self):
831 view = self.client[-1]
832 view = self.client[-1]
832 with capture_output() as io:
833 with capture_output() as io:
833 with view.sync_imports(quiet=True):
834 with view.sync_imports(quiet=True):
834 import IPython
835 import IPython
835 self.assertEqual(io.stdout, '')
836 self.assertEqual(io.stdout, '')
836
837
837 @interactive
838 @interactive
838 def find_ipython():
839 def find_ipython():
839 return 'IPython' in globals()
840 return 'IPython' in globals()
840
841
841 assert view.apply_sync(find_ipython)
842 assert view.apply_sync(find_ipython)
842
843
General Comments 0
You need to be logged in to leave comments. Login now