##// END OF EJS Templates
Merge pull request #1489 from minrk/ncpush...
Min RK -
r6299:6db7d230 merge
parent child Browse files
Show More
@@ -1,1069 +1,1069
1 """Views of remote engines.
1 """Views of remote engines.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import imp
18 import imp
19 import sys
19 import sys
20 import warnings
20 import warnings
21 from contextlib import contextmanager
21 from contextlib import contextmanager
22 from types import ModuleType
22 from types import ModuleType
23
23
24 import zmq
24 import zmq
25
25
26 from IPython.testing.skipdoctest import skip_doctest
26 from IPython.testing.skipdoctest import skip_doctest
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 )
29 )
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31
31
32 from IPython.parallel import util
32 from IPython.parallel import util
33 from IPython.parallel.controller.dependency import Dependency, dependent
33 from IPython.parallel.controller.dependency import Dependency, dependent
34
34
35 from . import map as Map
35 from . import map as Map
36 from .asyncresult import AsyncResult, AsyncMapResult
36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote, getname
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Decorators
40 # Decorators
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call.

    Records every msg_id the wrapped call added to ``self.client.history``
    into this view's own ``history`` and ``outstanding`` sets, even if the
    call raises.
    """
    n_previous = len(self.client.history)
    try:
        ret = f(self, *args, **kwargs)
    finally:
        # Slice from the saved length rather than computing a negative
        # index: the old `history[-nmsgs:]` form returned the *entire*
        # history when the call sent no messages (nmsgs == 0, and
        # `seq[-0:]` is `seq[0:]`), polluting history/outstanding.
        msg_ids = self.client.history[n_previous:]
        self.history.extend(msg_ids)
        # set.update replaces `map(self.outstanding.add, msg_ids)`:
        # clearer, and not dependent on map() being eager.
        self.outstanding.update(msg_ids)
    return ret
55
55
@decorator
def sync_results(f, self, *args, **kwargs):
    """sync relevant results from self.client to our results attribute."""
    ret = f(self, *args, **kwargs)
    # Anything we are still tracking that the client no longer considers
    # outstanding has completed since our last sync.
    finished = self.outstanding.difference(self.client.outstanding)
    done = self.outstanding.intersection(finished)
    self.outstanding = self.outstanding.difference(done)
    # pull the completed results over from the client's cache
    for mid in done:
        self.results[mid] = self.client.results[mid]
    return ret
66
66
@decorator
def spin_after(f, self, *args, **kwargs):
    """call spin after the method."""
    result = f(self, *args, **kwargs)
    # flush incoming results/registration state now that the call is done
    self.spin()
    return result
73
73
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75 # Classes
75 # Classes
76 #-----------------------------------------------------------------------------
76 #-----------------------------------------------------------------------------
77
77
@skip_doctest
class View(HasTraits):
    """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # flags
    block=Bool(False)   # whether to wait for results by default
    track=Bool(True)    # whether to ask zmq to track message delivery by default
    targets = Any()     # the engine target spec this view addresses

    history=List()      # msg_ids of every request sent through this view
    outstanding = Set() # msg_ids still awaiting results
    results = Dict()    # msg_id -> result, synced from self.client
    client = Instance('IPython.parallel.Client')

    _socket = Instance('zmq.Socket')
    _flag_names = List(['targets', 'block', 'track'])
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        """Create a View on `client` over `socket`.

        Extra keyword `flags` are applied via `set_flags`, so only names in
        `_flag_names` are accepted.
        """
        super(View, self).__init__(client=client, _socket=socket)
        # inherit the client's current blocking behavior as our default
        self.block = client.block

        self.set_flags(**flags)

        assert not self.__class__ is View, "Don't use base View objects, use subclasses"


    def __repr__(self):
        strtargets = str(self.targets)
        # truncate long target lists so the repr stays readable
        if len(strtargets) > 16:
            strtargets = strtargets[:12]+'...]'
        return "<%s %s>"%(self.__class__.__name__, strtargets)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit arrays and buffers after non-copying sends.

        Raises
        ------

        KeyError
            if a keyword is not one of the recognized flag names.
        """
        for name, value in kwargs.iteritems():
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r"%name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # preflight: save flags, and set temporaries
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        self.set_flags(**kwargs)
        # yield to the with-statement block
        try:
            yield
        finally:
            # postflight: restore saved flags even if the block raised
            self.set_flags(**saved_flags)


    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_message; implemented by subclasses."""
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

        returns AsyncResult
        """
        return self._really_apply(f, args, kwargs, block=False)

    @spin_after
    def apply_sync(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a blocking manner,
         returning the result.

        returns: actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------
    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all outstanding jobs.
            else: abort specific msg_id(s).
        targets : target spec, optional [default: self.targets]
        block : bool, optional [default: self.block]
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        jobs = jobs if jobs is not None else list(self.outstanding)

        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    # NOTE: the list defaults are never mutated, so sharing them across
    # calls is safe here; they are kept for backward compatibility
    # (the empty-list default deliberately does NOT expand to self.targets).
    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results."""
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)

    @spin_after
    def get_result(self, indices_or_msg_ids=None):
        """return one or more results, specified by history index or msg_id.

        See client.get_result for details.

        """

        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1
        # resolve history indices (ints) into msg_ids before delegating
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list,tuple,set)):
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i,index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=False)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f,*sequences,**kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=True)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f,*sequences,**kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of `itertools.imap`.

        See `self.map` for details.

        """

        return iter(self.map_async(f,*sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=True, **flags):
        """Decorator for making a RemoteFunction"""
        # NOTE(review): unlike `parallel` below, the default here is True,
        # so callers must pass block=None explicitly to inherit self.block
        # -- confirm this asymmetry is intended.
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction"""
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
370
370
371 @skip_doctest
371 @skip_doctest
372 class DirectView(View):
372 class DirectView(View):
373 """Direct Multiplexer View of one or more engines.
373 """Direct Multiplexer View of one or more engines.
374
374
375 These are created via indexed access to a client:
375 These are created via indexed access to a client:
376
376
377 >>> dv_1 = client[1]
377 >>> dv_1 = client[1]
378 >>> dv_all = client[:]
378 >>> dv_all = client[:]
379 >>> dv_even = client[::2]
379 >>> dv_even = client[::2]
380 >>> dv_some = client[1:3]
380 >>> dv_some = client[1:3]
381
381
382 This object provides dictionary access to engine namespaces:
382 This object provides dictionary access to engine namespaces:
383
383
384 # push a=5:
384 # push a=5:
385 >>> dv['a'] = 5
385 >>> dv['a'] = 5
386 # pull 'foo':
386 # pull 'foo':
387 >>> db['foo']
387 >>> db['foo']
388
388
389 """
389 """
390
390
    def __init__(self, client=None, socket=None, targets=None):
        # Delegate to View.__init__, which stores the client/socket pair;
        # `targets` is passed through **flags and applied via set_flags.
        super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393
393
    @property
    def importer(self):
        """sync_imports(local=True) as a property.

        Convenience accessor: ``with view.importer:`` is equivalent to
        ``with view.sync_imports(True):``.

        See sync_imports for details.

        """
        return self.sync_imports(True)
402
402
    @contextmanager
    def sync_imports(self, local=True, quiet=False):
        """Context Manager for performing simultaneous local and remote imports.

        'import x as y' will *not* work.  The 'as y' part will simply be ignored.

        If `local=True`, then the package will also be imported locally.

        If `quiet=True`, no output will be produced when attempting remote
        imports.

        Note that remote-only (`local=False`) imports have not been implemented.

        >>> with view.sync_imports():
        ...    from numpy import recarray
        importing recarray from numpy on engine(s)

        """
        # Works by temporarily replacing the builtin __import__ with a
        # wrapper that forwards each top-level import to the engines.
        import __builtin__
        local_import = __builtin__.__import__
        modules = set()   # "name:fromlist" keys already forwarded, to avoid duplicates
        results = []      # AsyncResults of the remote imports, checked on exit
        @util.interactive
        def remote_import(name, fromlist, level):
            """the function to be passed to apply, that actually performs the import
            on the engine, and loads up the user namespace.
            """
            import sys
            user_ns = globals()
            mod = __import__(name, fromlist=fromlist, level=level)
            if fromlist:
                # `from name import a, b` -> bind each requested attribute
                for key in fromlist:
                    user_ns[key] = getattr(mod, key)
            else:
                # plain `import name` -> bind the top-level module
                user_ns[name] = sys.modules[name]

        def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
            """the drop-in replacement for __import__, that optionally imports
            locally as well.
            """
            # don't override nested imports
            save_import = __builtin__.__import__
            __builtin__.__import__ = local_import

            if imp.lock_held():
                # this is a side-effect import, don't do it remotely, or even
                # ignore the local effects
                return local_import(name, globals, locals, fromlist, level)

            imp.acquire_lock()
            if local:
                mod = local_import(name, globals, locals, fromlist, level)
            else:
                raise NotImplementedError("remote-only imports not yet implemented")
            imp.release_lock()

            # only forward each distinct absolute (level == -1) import once
            key = name+':'+','.join(fromlist or [])
            if level == -1 and key not in modules:
                modules.add(key)
                if not quiet:
                    if fromlist:
                        print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
                    else:
                        print "importing %s on engine(s)"%name
                results.append(self.apply_async(remote_import, name, fromlist, level))
            # restore override
            __builtin__.__import__ = save_import

            return mod

        # override __import__
        __builtin__.__import__ = view_import
        try:
            # enter the block
            yield
        except ImportError:
            if local:
                raise
            else:
                # ignore import errors if not doing local imports
                pass
        finally:
            # always restore __import__
            __builtin__.__import__ = local_import

        for r in results:
            # raise possible remote ImportErrors here
            r.get()
491
491
492
492
    @sync_results
    @save_ids
    def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all of `apply`'s flags via this View's attributes.

        Parameters
        ----------

        f : callable

        args : list [default: empty]

        kwargs : dict [default: empty]

        targets : target list [default: self.targets]
            where to run
        block : bool [default: self.block]
            whether to block
        track : bool [default: self.track]
            whether to ask zmq to track the message, for safe non-copying sends

        Returns
        -------

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs) on the engine(s)
            This will be a list if self.targets is also a list (even length 1), or
            the single result if self.targets is an integer engine id
        """
        # fall back to the view-level flags for anything not passed explicitly
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs
        block = self.block if block is None else block
        track = self.track if track is None else track
        targets = self.targets if targets is None else targets

        # one apply message per engine identity in the target spec
        _idents = self.client._build_targets(targets)[0]
        msg_ids = []
        trackers = []
        for ident in _idents:
            msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
                                ident=ident)
            if track:
                trackers.append(msg['tracker'])
            msg_ids.append(msg['header']['msg_id'])
        # aggregate per-send trackers into a single MessageTracker (None if untracked)
        tracker = None if track is False else zmq.MessageTracker(*trackers)
        ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # on Ctrl-C, fall through and hand back the AsyncResult
                # so the caller can still retrieve results later
                pass
        return ar
549
549
@spin_after
def map(self, f, *sequences, **kwargs):
    """view.map(f, *sequences, block=self.block) => list|AsyncMapResult

    Parallel version of builtin `map`, using this View's `targets`.

    Work is chunked with one task per target, so sequences longer than
    `targets` are split into pieces.  Results can be iterated as they
    become ready, but arrive in chunks.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool
        whether to wait for the result or not [default self.block]

    Returns
    -------

    if block=False:
        AsyncMapResult
            An object like AsyncResult, but which reassembles the sequence of results
            into a single list. AsyncMapResults can be iterated through before all
            results are complete.
    else:
        list
            the result of map(f,*sequences)
    """

    block = kwargs.pop('block', self.block)
    # reject any keyword other than the supported flags
    unknown = [k for k in kwargs.keys() if k not in ['block', 'track']]
    if unknown:
        raise TypeError("invalid keyword arg, %r"%unknown[0])

    assert len(sequences) > 0, "must have some sequences to map onto!"
    return ParallelFunction(self, f, block=block, **kwargs).map(*sequences)
592
592
def execute(self, code, targets=None, block=None):
    """Executes `code` on `targets` in blocking or nonblocking manner.

    ``execute`` is always `bound` (affects engine namespace)

    Parameters
    ----------

    code : str
        the code string to be executed
    block : bool
        whether or not to wait until done to return
        default: self.block
    """
    # thin wrapper: util._execute runs the string on each engine
    return self._really_apply(
        util._execute, args=(code,), block=block, targets=targets)
608
608
def run(self, filename, targets=None, block=None):
    """Execute contents of `filename` on my engine(s).

    This simply reads the contents of the file and calls `execute`.

    Parameters
    ----------

    filename : str
        The path to the file
    targets : int/str/list of ints/strs
        the engines on which to execute
        default : all
    block : bool
        whether or not to wait until done
        default: self.block

    """
    with open(filename, 'r') as f:
        source = f.read()
    # appended newline guards against a SyntaxError caused by
    # trailing indented whitespace at the end of the file
    return self.execute(source + '\n', block=block, targets=targets)
632
632
def update(self, ns):
    """update remote namespace with dict `ns`

    See `push` for details.
    """
    # delegate to push() using this View's current flags
    block, track = self.block, self.track
    return self.push(ns, block=block, track=track)
639
639
def push(self, ns, targets=None, block=None, track=None):
    """update remote namespace with dict `ns`

    Parameters
    ----------

    ns : dict
        dict of keys with which to update engine namespace(s)
    block : bool [default : self.block]
        whether to wait to be notified of engine receipt

    """
    # validate before resolving flags -- push only accepts a dict
    if not isinstance(ns, dict):
        raise TypeError("Must be a dict, not %s"%type(ns))
    block = self.block if block is None else block
    track = self.track if track is None else track
    targets = self.targets if targets is None else targets
    # the namespace travels as the kwargs of util._push
    return self._really_apply(util._push, kwargs=ns, block=block,
                              track=track, targets=targets)
660
660
def get(self, key_s):
    """get object(s) by `key_s` from remote namespace

    see `pull` for details.
    """
    # block = block if block is not None else self.block
    # always blocking: dict-style access must yield the value immediately
    return self.pull(key_s, block=True)
668
668
def pull(self, names, targets=None, block=None):
    """get object(s) by `names` from remote namespace

    will return one object if it is a key.
    can also take a list of keys, in which case it will return a list of objects.

    Parameters
    ----------

    names : str or list/tuple/set of str
        name(s) of the object(s) to fetch from the engine namespace(s)
    targets : target list [default: self.targets]
        where to pull from
    block : bool [default: self.block]
        whether to wait for the result

    Raises
    ------

    TypeError
        if `names` is not a str or a flat collection of str
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    # (the unused `applier` local that used to be computed here was dead
    # code: dispatch goes through _really_apply below)
    if isinstance(names, basestring):
        pass
    elif isinstance(names, (list, tuple, set)):
        for key in names:
            if not isinstance(key, basestring):
                raise TypeError("keys must be str, not type %r"%type(key))
    else:
        raise TypeError("names must be strs, not %r"%names)
    return self._really_apply(util._pull, (names,), block=block, targets=targets)
687
687
def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
    """
    Partition a Python sequence and send the partitions to a set of engines.

    Parameters
    ----------

    key : str
        name under which each partition is stored on the engines
    seq : sequence
        the sequence to partition
    dist : str [default: 'b']
        key into Map.dists selecting the partitioning scheme
    flatten : bool [default: False]
        if True, single-element partitions are pushed as the bare element
        rather than a length-1 list
    targets : target list [default: self.targets]
    block : bool [default: self.block]
        if True, wait for all pushes to complete (returns None);
        otherwise return an AsyncResult
    track : bool [default: self.track]
        whether to track the zmq sends
    """
    block = block if block is not None else self.block
    track = track if track is not None else self.track
    targets = targets if targets is not None else self.targets

    mapObject = Map.dists[dist]()
    nparts = len(targets)
    msg_ids = []
    trackers = []
    # one push per engine, each carrying that engine's partition
    for index, engineid in enumerate(targets):
        partition = mapObject.getPartition(seq, index, nparts)
        if flatten and len(partition) == 1:
            ns = {key: partition[0]}
        else:
            ns = {key: partition}
        r = self.push(ns, block=False, track=track, targets=engineid)
        msg_ids.extend(r.msg_ids)
        if track:
            trackers.append(r._tracker)

    # merge the per-push trackers into a single MessageTracker
    if track:
        tracker = zmq.MessageTracker(*trackers)
    else:
        tracker = None

    r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
    if block:
        # blocking call returns None once every push has completed
        r.wait()
    else:
        return r
721
721
@sync_results
@save_ids
def gather(self, key, dist='b', targets=None, block=None):
    """
    Gather a partitioned sequence on a set of engines as a single local seq.

    Parameters
    ----------

    key : str
        name of the object to pull from each engine
    dist : str [default: 'b']
        key into Map.dists selecting the scheme used to reassemble parts
    targets : target list [default: self.targets]
    block : bool [default: self.block]
        whether to wait for, and return, the joined result
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    mapObject = Map.dists[dist]()
    msg_ids = []

    # one non-blocking pull per engine; the map object later joins the
    # partitions in engine order (enumerate index was unused - removed)
    for engineid in targets:
        msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)

    r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')

    if block:
        try:
            return r.get()
        except KeyboardInterrupt:
            # interrupted: fall through and return the AsyncMapResult
            pass
    return r
744
744
def __getitem__(self, key):
    """dict-style access: ``view[key]`` is a blocking pull of `key`."""
    return self.get(key)
747
747
def __setitem__(self,key, value):
    """dict-style assignment: ``view[key] = value`` pushes to the engines."""
    self.update({key:value})
750
750
def clear(self, targets=None, block=False):
    """Clear the remote namespaces on my engines."""
    # NOTE(review): with block=False as the signature default, the
    # `is not None` fallback below can never pick up self.block --
    # confirm whether the default was meant to be None as in siblings.
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    return self.client.clear(targets=targets, block=block)
756
756
def kill(self, targets=None, block=True):
    """Kill my engines."""
    # NOTE(review): with block=True as the signature default, the
    # `is not None` fallback below can never pick up self.block --
    # confirm whether the default was meant to be None as in siblings.
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    return self.client.kill(targets=targets, block=block)
762
762
#----------------------------------------
# activate for %px,%autopx magics
#----------------------------------------
def activate(self):
    """Make this `View` active for parallel magic commands.

    IPython has a magic command syntax to work with `MultiEngineClient` objects.
    In a given IPython session there is a single active one. While
    there can be many `Views` created and used by the user,
    there is only one active one. The active `View` is used whenever
    the magic commands %px and %autopx are used.

    The activate() method is called on a given `View` to make it
    active. Once this has been done, the magic commands can be used.
    """

    try:
        # This is injected into __builtins__.
        ip = get_ipython()
    except NameError:
        # not running inside IPython: the magics cannot be registered
        print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
    else:
        # lazily load the parallelmagic extension on first use, then
        # point it at this view
        pmagic = ip.plugin_manager.get_plugin('parallelmagic')
        if pmagic is None:
            ip.magic_load_ext('parallelmagic')
            pmagic = ip.plugin_manager.get_plugin('parallelmagic')

        pmagic.active_view = self
791
791
792
792
@skip_doctest
class LoadBalancedView(View):
    """A load-balancing View that only executes via the Task scheduler.

    Load-balanced views can be created with the client's `view` method:

    >>> v = client.load_balanced_view()

    or targets can be specified, to restrict the potential destinations:

    >>> v = client.client.load_balanced_view([1,3])

    which would restrict loadbalancing to between engines 1 and 3.

    """

    # scheduler-only flags; see set_flags for their meanings
    follow=Any()
    after=Any()
    timeout=CFloat()
    retries = Integer(0)

    _task_scheme = Any()
    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])

    def __init__(self, client=None, socket=None, **flags):
        super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
        # remember the client's scheduler scheme; the pure ZMQ scheme
        # supports none of the dependency/retry/timeout features
        self._task_scheme=client._task_scheme

    def _validate_dependency(self, dep):
        """validate a dependency.

        For use in `set_flags`.
        """
        if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
            return True
        elif isinstance(dep, (list,set, tuple)):
            # each element must be a msg_id or AsyncResult
            for d in dep:
                if not isinstance(d, (basestring, AsyncResult)):
                    return False
        elif isinstance(dep, dict):
            # a dict must match the exact schema of Dependency.as_dict()
            if set(dep.keys()) != set(Dependency().as_dict().keys()):
                return False
            if not isinstance(dep['msg_ids'], list):
                return False
            for d in dep['msg_ids']:
                if not isinstance(d, basestring):
                    return False
        else:
            return False

        return True

    def _render_dependency(self, dep):
        """helper for building jsonable dependencies from various input forms."""
        if isinstance(dep, Dependency):
            return dep.as_dict()
        elif isinstance(dep, AsyncResult):
            return dep.msg_ids
        elif dep is None:
            return []
        else:
            # pass to Dependency constructor
            return list(Dependency(dep))

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        A View is a wrapper for the Client's apply method, but with attributes
        that specify keyword arguments, those attributes can be set by keyword
        argument with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.

        after : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a time-based dependency.
            This job will only be run *after* the dependencies
            have been met.

        follow : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a location-based dependency.
            This job will only be run on an engine where this dependency
            is met.

        timeout : float/int or None
            Only for load-balanced execution (targets=None)
            Specify an amount of time (in seconds) for the scheduler to
            wait for dependencies to be met before failing with a
            DependencyTimeout.

        retries : int
            Number of times a task will be retried on failure.
        """

        super(LoadBalancedView, self).set_flags(**kwargs)
        for name in ('follow', 'after'):
            if name in kwargs:
                value = kwargs[name]
                if self._validate_dependency(value):
                    setattr(self, name, value)
                else:
                    raise ValueError("Invalid dependency: %r"%value)
        if 'timeout' in kwargs:
            t = kwargs['timeout']
            if not isinstance(t, (int, long, float, type(None))):
                raise TypeError("Invalid type for timeout: %r"%type(t))
            if t is not None:
                if t < 0:
                    raise ValueError("Invalid timeout: %s"%t)
            self.timeout = t

    @sync_results
    @save_ids
    def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
                        after=None, follow=None, timeout=None,
                        targets=None, retries=None):
        """calls f(*args, **kwargs) on a remote engine, returning the result.

        This method temporarily sets all of `apply`'s flags for a single call.

        Parameters
        ----------

        f : callable

        args : list [default: empty]

        kwargs : dict [default: empty]

        block : bool [default: self.block]
            whether to block
        track : bool [default: self.track]
            whether to ask zmq to track the message, for safe non-copying sends
        after : Dependency or collection of msg_ids [default: self.after]
            time-based dependency: run only after these tasks have completed
        follow : Dependency or collection of msg_ids [default: self.follow]
            location-based dependency: run only where these tasks ran
        timeout : float/int or None [default: self.timeout]
            seconds the scheduler waits for dependencies before
            failing with DependencyTimeout
        targets : target list [default: self.targets]
            restrict the engines eligible to run the task (None = any)
        retries : int [default: self.retries]
            number of times to retry the task on failure

        Returns
        -------

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs) on the engine(s)
            This will be a list if self.targets is also a list (even length 1), or
            the single result if self.targets is an integer engine id
        """

        # validate whether we can run
        if self._socket.closed:
            msg = "Task farming is disabled"
            if self._task_scheme == 'pure':
                msg += " because the pure ZMQ scheduler cannot handle"
                msg += " disappearing engines."
            raise RuntimeError(msg)

        if self._task_scheme == 'pure':
            # pure zmq scheme doesn't support extra features.
            # FIX: the flag list used to be a free-standing string
            # expression on its own line, so it never reached the message.
            msg = ("Pure ZMQ scheduler doesn't support the following flags: "
                   "follow, after, retries, targets, timeout")
            if (follow or after or retries or targets or timeout):
                # hard fail on Scheduler flags
                raise RuntimeError(msg)
            if isinstance(f, dependent):
                # soft warn on functional dependencies
                warnings.warn(msg, RuntimeWarning)

        # build args
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs
        block = self.block if block is None else block
        track = self.track if track is None else track
        after = self.after if after is None else after
        retries = self.retries if retries is None else retries
        follow = self.follow if follow is None else follow
        timeout = self.timeout if timeout is None else timeout
        targets = self.targets if targets is None else targets

        if not isinstance(retries, int):
            raise TypeError('retries must be int, not %r'%type(retries))

        if targets is None:
            # empty targets list means "any engine" to the scheduler
            idents = []
        else:
            idents = self.client._build_targets(targets)[0]
            # ensure *not* bytes
            idents = [ ident.decode() for ident in idents ]

        after = self._render_dependency(after)
        follow = self._render_dependency(follow)
        subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)

        msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
                                subheader=subheader)
        tracker = None if track is False else msg['tracker']

        ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)

        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # interrupted while waiting: still return the AsyncResult
                pass
        return ar

    @spin_after
    @save_ids
    def map(self, f, *sequences, **kwargs):
        """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult

        Parallel version of builtin `map`, load-balanced by this View.

        `block`, and `chunksize` can be specified by keyword only.

        Each `chunksize` elements will be a separate task, and will be
        load-balanced. This lets individual elements be available for iteration
        as soon as they arrive.

        Parameters
        ----------

        f : callable
            function to be mapped
        *sequences: one or more sequences of matching length
            the sequences to be distributed and passed to `f`
        block : bool [default self.block]
            whether to wait for the result or not
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.
        chunksize : int [default 1]
            how many elements should be in each task.
        ordered : bool [default True]
            Whether the results should be gathered as they arrive, or enforce
            the order of submission.

            Only applies when iterating through AsyncMapResult as results arrive.
            Has no effect when block=True.

        Returns
        -------

        if block=False:
            AsyncMapResult
                An object like AsyncResult, but which reassembles the sequence of results
                into a single list. AsyncMapResults can be iterated through before all
                results are complete.
        else:
            the result of map(f,*sequences)

        """

        # default
        block = kwargs.get('block', self.block)
        chunksize = kwargs.get('chunksize', 1)
        ordered = kwargs.get('ordered', True)

        # FIX: this used to call set.difference_update, which mutates in
        # place and returns None, so extra_keys was always falsy and
        # invalid kwargs were silently accepted; 'ordered' was also
        # missing from the valid-key set.
        keyset = set(kwargs.keys())
        extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
        if extra_keys:
            raise TypeError("Invalid kwargs: %s"%list(extra_keys))

        assert len(sequences) > 0, "must have some sequences to map onto!"

        pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
        return pf.map(*sequences)
1068
1068
# public names exported by ``from ... import *``
__all__ = ['LoadBalancedView', 'DirectView']
@@ -1,520 +1,537
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28
28
29 from IPython import parallel as pmod
29 from IPython import parallel as pmod
30 from IPython.parallel import error
30 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
32 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
38
39 def setup():
39 def setup():
40 add_engines(3, total=True)
40 add_engines(3, total=True)
41
41
42 class TestView(ClusterTestCase):
42 class TestView(ClusterTestCase):
43
43
44 def test_z_crash_mux(self):
44 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
45 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEquals(d, data)
69 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
72 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEquals(r, nengines*[data])
79 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
82 self.assertEquals(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEquals(r, testf(10))
100 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
110
111 def test_push_function_globals(self):
111 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
112 """test that pushed functions have access to globals"""
113 @interactive
113 @interactive
114 def geta():
114 def geta():
115 return a
115 return a
116 # self.add_engines(1)
116 # self.add_engines(1)
117 v = self.client[-1]
117 v = self.client[-1]
118 v.block=True
118 v.block=True
119 v['f'] = geta
119 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
121 v.execute('a=5')
122 v.execute('b=f()')
122 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
123 self.assertEquals(v['b'], 5)
124
124
125 def test_push_function_defaults(self):
125 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
126 """test that pushed functions preserve default args"""
127 def echo(a=10):
127 def echo(a=10):
128 return a
128 return a
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = echo
131 v['f'] = echo
132 v.execute('b=f()')
132 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
133 self.assertEquals(v['b'], 10)
134
134
135 def test_get_result(self):
135 def test_get_result(self):
136 """test getting results from the Hub."""
136 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
137 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
138 # self.add_engines(1)
139 t = c.ids[-1]
139 t = c.ids[-1]
140 v = c[t]
140 v = c[t]
141 v2 = self.client[t]
141 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
142 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
143 # give the monitor time to notice the message
144 time.sleep(.25)
144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
145 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
147 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
148 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
150 c.spin()
151 c.close()
151 c.close()
152
152
153 def test_run_newline(self):
153 def test_run_newline(self):
154 """test that run appends newline to files"""
154 """test that run appends newline to files"""
155 tmpfile = mktemp()
155 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
156 with open(tmpfile, 'w') as f:
157 f.write("""def g():
157 f.write("""def g():
158 return 5
158 return 5
159 """)
159 """)
160 v = self.client[-1]
160 v = self.client[-1]
161 v.run(tmpfile, block=True)
161 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
163
164 def test_apply_tracked(self):
164 def test_apply_tracked(self):
165 """test tracking for apply"""
165 """test tracking for apply"""
166 # self.add_engines(1)
166 # self.add_engines(1)
167 t = self.client.ids[-1]
167 t = self.client.ids[-1]
168 v = self.client[t]
168 v = self.client[t]
169 v.block=False
169 v.block=False
170 def echo(n=1024*1024, **kwargs):
170 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
171 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
172 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
173 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
175 self.assertTrue(ar.sent)
176 ar = echo(track=True)
176 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
178 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
179 ar._tracker.wait()
180 self.assertTrue(ar.sent)
180 self.assertTrue(ar.sent)
181
181
182 def test_push_tracked(self):
182 def test_push_tracked(self):
183 t = self.client.ids[-1]
183 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
184 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
185 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
186 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
189
189
190 ar = v.push(ns, block=False, track=True)
190 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
192 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
194 self.assertTrue(ar.sent)
195 ar.get()
195 ar.get()
196
196
197 def test_scatter_tracked(self):
197 def test_scatter_tracked(self):
198 t = self.client.ids
198 t = self.client.ids
199 x='x'*1024*1024
199 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
202 self.assertTrue(ar.sent)
203
203
204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
206 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
207 ar._tracker.wait()
208 self.assertTrue(ar.sent)
208 self.assertTrue(ar.sent)
209 ar.get()
209 ar.get()
210
210
211 def test_remote_reference(self):
211 def test_remote_reference(self):
212 v = self.client[-1]
212 v = self.client[-1]
213 v['a'] = 123
213 v['a'] = 123
214 ra = pmod.Reference('a')
214 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
215 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
216 self.assertEquals(b, 123)
217
217
218
218
219 def test_scatter_gather(self):
219 def test_scatter_gather(self):
220 view = self.client[:]
220 view = self.client[:]
221 seq1 = range(16)
221 seq1 = range(16)
222 view.scatter('a', seq1)
222 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
223 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
224 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
226
227 @skip_without('numpy')
227 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
228 def test_scatter_gather_numpy(self):
229 import numpy
229 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
231 view = self.client[:]
232 a = numpy.arange(64)
232 a = numpy.arange(64)
233 view.scatter('a', a)
233 view.scatter('a', a)
234 b = view.gather('a', block=True)
234 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
235 assert_array_equal(b, a)
236
237 @skip_without('numpy')
238 def test_push_numpy_nocopy(self):
239 import numpy
240 view = self.client[:]
241 a = numpy.arange(64)
242 view['A'] = a
243 @interactive
244 def check_writeable(x):
245 return x.flags.writeable
246
247 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
248 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
249
250 view.push(dict(B=a))
251 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
252 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
236
253
237 @skip_without('numpy')
254 @skip_without('numpy')
238 def test_apply_numpy(self):
255 def test_apply_numpy(self):
239 """view.apply(f, ndarray)"""
256 """view.apply(f, ndarray)"""
240 import numpy
257 import numpy
241 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
258 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
242
259
243 A = numpy.random.random((100,100))
260 A = numpy.random.random((100,100))
244 view = self.client[-1]
261 view = self.client[-1]
245 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
262 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
246 B = A.astype(dt)
263 B = A.astype(dt)
247 C = view.apply_sync(lambda x:x, B)
264 C = view.apply_sync(lambda x:x, B)
248 assert_array_equal(B,C)
265 assert_array_equal(B,C)
249
266
250 def test_map(self):
267 def test_map(self):
251 view = self.client[:]
268 view = self.client[:]
252 def f(x):
269 def f(x):
253 return x**2
270 return x**2
254 data = range(16)
271 data = range(16)
255 r = view.map_sync(f, data)
272 r = view.map_sync(f, data)
256 self.assertEquals(r, map(f, data))
273 self.assertEquals(r, map(f, data))
257
274
258 def test_map_iterable(self):
275 def test_map_iterable(self):
259 """test map on iterables (direct)"""
276 """test map on iterables (direct)"""
260 view = self.client[:]
277 view = self.client[:]
261 # 101 is prime, so it won't be evenly distributed
278 # 101 is prime, so it won't be evenly distributed
262 arr = range(101)
279 arr = range(101)
263 # ensure it will be an iterator, even in Python 3
280 # ensure it will be an iterator, even in Python 3
264 it = iter(arr)
281 it = iter(arr)
265 r = view.map_sync(lambda x:x, arr)
282 r = view.map_sync(lambda x:x, arr)
266 self.assertEquals(r, list(arr))
283 self.assertEquals(r, list(arr))
267
284
268 def test_scatterGatherNonblocking(self):
285 def test_scatterGatherNonblocking(self):
269 data = range(16)
286 data = range(16)
270 view = self.client[:]
287 view = self.client[:]
271 view.scatter('a', data, block=False)
288 view.scatter('a', data, block=False)
272 ar = view.gather('a', block=False)
289 ar = view.gather('a', block=False)
273 self.assertEquals(ar.get(), data)
290 self.assertEquals(ar.get(), data)
274
291
275 @skip_without('numpy')
292 @skip_without('numpy')
276 def test_scatter_gather_numpy_nonblocking(self):
293 def test_scatter_gather_numpy_nonblocking(self):
277 import numpy
294 import numpy
278 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
295 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
279 a = numpy.arange(64)
296 a = numpy.arange(64)
280 view = self.client[:]
297 view = self.client[:]
281 ar = view.scatter('a', a, block=False)
298 ar = view.scatter('a', a, block=False)
282 self.assertTrue(isinstance(ar, AsyncResult))
299 self.assertTrue(isinstance(ar, AsyncResult))
283 amr = view.gather('a', block=False)
300 amr = view.gather('a', block=False)
284 self.assertTrue(isinstance(amr, AsyncMapResult))
301 self.assertTrue(isinstance(amr, AsyncMapResult))
285 assert_array_equal(amr.get(), a)
302 assert_array_equal(amr.get(), a)
286
303
287 def test_execute(self):
304 def test_execute(self):
288 view = self.client[:]
305 view = self.client[:]
289 # self.client.debug=True
306 # self.client.debug=True
290 execute = view.execute
307 execute = view.execute
291 ar = execute('c=30', block=False)
308 ar = execute('c=30', block=False)
292 self.assertTrue(isinstance(ar, AsyncResult))
309 self.assertTrue(isinstance(ar, AsyncResult))
293 ar = execute('d=[0,1,2]', block=False)
310 ar = execute('d=[0,1,2]', block=False)
294 self.client.wait(ar, 1)
311 self.client.wait(ar, 1)
295 self.assertEquals(len(ar.get()), len(self.client))
312 self.assertEquals(len(ar.get()), len(self.client))
296 for c in view['c']:
313 for c in view['c']:
297 self.assertEquals(c, 30)
314 self.assertEquals(c, 30)
298
315
299 def test_abort(self):
316 def test_abort(self):
300 view = self.client[-1]
317 view = self.client[-1]
301 ar = view.execute('import time; time.sleep(1)', block=False)
318 ar = view.execute('import time; time.sleep(1)', block=False)
302 ar2 = view.apply_async(lambda : 2)
319 ar2 = view.apply_async(lambda : 2)
303 ar3 = view.apply_async(lambda : 3)
320 ar3 = view.apply_async(lambda : 3)
304 view.abort(ar2)
321 view.abort(ar2)
305 view.abort(ar3.msg_ids)
322 view.abort(ar3.msg_ids)
306 self.assertRaises(error.TaskAborted, ar2.get)
323 self.assertRaises(error.TaskAborted, ar2.get)
307 self.assertRaises(error.TaskAborted, ar3.get)
324 self.assertRaises(error.TaskAborted, ar3.get)
308
325
309 def test_abort_all(self):
326 def test_abort_all(self):
310 """view.abort() aborts all outstanding tasks"""
327 """view.abort() aborts all outstanding tasks"""
311 view = self.client[-1]
328 view = self.client[-1]
312 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
329 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
313 view.abort()
330 view.abort()
314 view.wait(timeout=5)
331 view.wait(timeout=5)
315 for ar in ars[5:]:
332 for ar in ars[5:]:
316 self.assertRaises(error.TaskAborted, ar.get)
333 self.assertRaises(error.TaskAborted, ar.get)
317
334
318 def test_temp_flags(self):
335 def test_temp_flags(self):
319 view = self.client[-1]
336 view = self.client[-1]
320 view.block=True
337 view.block=True
321 with view.temp_flags(block=False):
338 with view.temp_flags(block=False):
322 self.assertFalse(view.block)
339 self.assertFalse(view.block)
323 self.assertTrue(view.block)
340 self.assertTrue(view.block)
324
341
325 @dec.known_failure_py3
342 @dec.known_failure_py3
326 def test_importer(self):
343 def test_importer(self):
327 view = self.client[-1]
344 view = self.client[-1]
328 view.clear(block=True)
345 view.clear(block=True)
329 with view.importer:
346 with view.importer:
330 import re
347 import re
331
348
332 @interactive
349 @interactive
333 def findall(pat, s):
350 def findall(pat, s):
334 # this globals() step isn't necessary in real code
351 # this globals() step isn't necessary in real code
335 # only to prevent a closure in the test
352 # only to prevent a closure in the test
336 re = globals()['re']
353 re = globals()['re']
337 return re.findall(pat, s)
354 return re.findall(pat, s)
338
355
339 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
356 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
340
357
341 # parallel magic tests
358 # parallel magic tests
342
359
343 def test_magic_px_blocking(self):
360 def test_magic_px_blocking(self):
344 ip = get_ipython()
361 ip = get_ipython()
345 v = self.client[-1]
362 v = self.client[-1]
346 v.activate()
363 v.activate()
347 v.block=True
364 v.block=True
348
365
349 ip.magic_px('a=5')
366 ip.magic_px('a=5')
350 self.assertEquals(v['a'], 5)
367 self.assertEquals(v['a'], 5)
351 ip.magic_px('a=10')
368 ip.magic_px('a=10')
352 self.assertEquals(v['a'], 10)
369 self.assertEquals(v['a'], 10)
353 sio = StringIO()
370 sio = StringIO()
354 savestdout = sys.stdout
371 savestdout = sys.stdout
355 sys.stdout = sio
372 sys.stdout = sio
356 # just 'print a' works ~99% of the time, but this ensures that
373 # just 'print a' works ~99% of the time, but this ensures that
357 # the stdout message has arrived when the result is finished:
374 # the stdout message has arrived when the result is finished:
358 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
375 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
359 sys.stdout = savestdout
376 sys.stdout = savestdout
360 buf = sio.getvalue()
377 buf = sio.getvalue()
361 self.assertTrue('[stdout:' in buf, buf)
378 self.assertTrue('[stdout:' in buf, buf)
362 self.assertTrue(buf.rstrip().endswith('10'))
379 self.assertTrue(buf.rstrip().endswith('10'))
363 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
380 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
364
381
365 def test_magic_px_nonblocking(self):
382 def test_magic_px_nonblocking(self):
366 ip = get_ipython()
383 ip = get_ipython()
367 v = self.client[-1]
384 v = self.client[-1]
368 v.activate()
385 v.activate()
369 v.block=False
386 v.block=False
370
387
371 ip.magic_px('a=5')
388 ip.magic_px('a=5')
372 self.assertEquals(v['a'], 5)
389 self.assertEquals(v['a'], 5)
373 ip.magic_px('a=10')
390 ip.magic_px('a=10')
374 self.assertEquals(v['a'], 10)
391 self.assertEquals(v['a'], 10)
375 sio = StringIO()
392 sio = StringIO()
376 savestdout = sys.stdout
393 savestdout = sys.stdout
377 sys.stdout = sio
394 sys.stdout = sio
378 ip.magic_px('print a')
395 ip.magic_px('print a')
379 sys.stdout = savestdout
396 sys.stdout = savestdout
380 buf = sio.getvalue()
397 buf = sio.getvalue()
381 self.assertFalse('[stdout:%i]'%v.targets in buf)
398 self.assertFalse('[stdout:%i]'%v.targets in buf)
382 ip.magic_px('1/0')
399 ip.magic_px('1/0')
383 ar = v.get_result(-1)
400 ar = v.get_result(-1)
384 self.assertRaisesRemote(ZeroDivisionError, ar.get)
401 self.assertRaisesRemote(ZeroDivisionError, ar.get)
385
402
386 def test_magic_autopx_blocking(self):
403 def test_magic_autopx_blocking(self):
387 ip = get_ipython()
404 ip = get_ipython()
388 v = self.client[-1]
405 v = self.client[-1]
389 v.activate()
406 v.activate()
390 v.block=True
407 v.block=True
391
408
392 sio = StringIO()
409 sio = StringIO()
393 savestdout = sys.stdout
410 savestdout = sys.stdout
394 sys.stdout = sio
411 sys.stdout = sio
395 ip.magic_autopx()
412 ip.magic_autopx()
396 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
413 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
397 ip.run_cell('print b')
414 ip.run_cell('print b')
398 ip.run_cell("b/c")
415 ip.run_cell("b/c")
399 ip.run_code(compile('b*=2', '', 'single'))
416 ip.run_code(compile('b*=2', '', 'single'))
400 ip.magic_autopx()
417 ip.magic_autopx()
401 sys.stdout = savestdout
418 sys.stdout = savestdout
402 output = sio.getvalue().strip()
419 output = sio.getvalue().strip()
403 self.assertTrue(output.startswith('%autopx enabled'))
420 self.assertTrue(output.startswith('%autopx enabled'))
404 self.assertTrue(output.endswith('%autopx disabled'))
421 self.assertTrue(output.endswith('%autopx disabled'))
405 self.assertTrue('RemoteError: ZeroDivisionError' in output)
422 self.assertTrue('RemoteError: ZeroDivisionError' in output)
406 ar = v.get_result(-2)
423 ar = v.get_result(-2)
407 self.assertEquals(v['a'], 5)
424 self.assertEquals(v['a'], 5)
408 self.assertEquals(v['b'], 20)
425 self.assertEquals(v['b'], 20)
409 self.assertRaisesRemote(ZeroDivisionError, ar.get)
426 self.assertRaisesRemote(ZeroDivisionError, ar.get)
410
427
411 def test_magic_autopx_nonblocking(self):
428 def test_magic_autopx_nonblocking(self):
412 ip = get_ipython()
429 ip = get_ipython()
413 v = self.client[-1]
430 v = self.client[-1]
414 v.activate()
431 v.activate()
415 v.block=False
432 v.block=False
416
433
417 sio = StringIO()
434 sio = StringIO()
418 savestdout = sys.stdout
435 savestdout = sys.stdout
419 sys.stdout = sio
436 sys.stdout = sio
420 ip.magic_autopx()
437 ip.magic_autopx()
421 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
438 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
422 ip.run_cell('print b')
439 ip.run_cell('print b')
423 ip.run_cell("b/c")
440 ip.run_cell("b/c")
424 ip.run_code(compile('b*=2', '', 'single'))
441 ip.run_code(compile('b*=2', '', 'single'))
425 ip.magic_autopx()
442 ip.magic_autopx()
426 sys.stdout = savestdout
443 sys.stdout = savestdout
427 output = sio.getvalue().strip()
444 output = sio.getvalue().strip()
428 self.assertTrue(output.startswith('%autopx enabled'))
445 self.assertTrue(output.startswith('%autopx enabled'))
429 self.assertTrue(output.endswith('%autopx disabled'))
446 self.assertTrue(output.endswith('%autopx disabled'))
430 self.assertFalse('ZeroDivisionError' in output)
447 self.assertFalse('ZeroDivisionError' in output)
431 ar = v.get_result(-2)
448 ar = v.get_result(-2)
432 self.assertEquals(v['a'], 5)
449 self.assertEquals(v['a'], 5)
433 self.assertEquals(v['b'], 20)
450 self.assertEquals(v['b'], 20)
434 self.assertRaisesRemote(ZeroDivisionError, ar.get)
451 self.assertRaisesRemote(ZeroDivisionError, ar.get)
435
452
436 def test_magic_result(self):
453 def test_magic_result(self):
437 ip = get_ipython()
454 ip = get_ipython()
438 v = self.client[-1]
455 v = self.client[-1]
439 v.activate()
456 v.activate()
440 v['a'] = 111
457 v['a'] = 111
441 ra = v['a']
458 ra = v['a']
442
459
443 ar = ip.magic_result()
460 ar = ip.magic_result()
444 self.assertEquals(ar.msg_ids, [v.history[-1]])
461 self.assertEquals(ar.msg_ids, [v.history[-1]])
445 self.assertEquals(ar.get(), 111)
462 self.assertEquals(ar.get(), 111)
446 ar = ip.magic_result('-2')
463 ar = ip.magic_result('-2')
447 self.assertEquals(ar.msg_ids, [v.history[-2]])
464 self.assertEquals(ar.msg_ids, [v.history[-2]])
448
465
449 def test_unicode_execute(self):
466 def test_unicode_execute(self):
450 """test executing unicode strings"""
467 """test executing unicode strings"""
451 v = self.client[-1]
468 v = self.client[-1]
452 v.block=True
469 v.block=True
453 if sys.version_info[0] >= 3:
470 if sys.version_info[0] >= 3:
454 code="a='é'"
471 code="a='é'"
455 else:
472 else:
456 code=u"a=u'é'"
473 code=u"a=u'é'"
457 v.execute(code)
474 v.execute(code)
458 self.assertEquals(v['a'], u'é')
475 self.assertEquals(v['a'], u'é')
459
476
460 def test_unicode_apply_result(self):
477 def test_unicode_apply_result(self):
461 """test unicode apply results"""
478 """test unicode apply results"""
462 v = self.client[-1]
479 v = self.client[-1]
463 r = v.apply_sync(lambda : u'é')
480 r = v.apply_sync(lambda : u'é')
464 self.assertEquals(r, u'é')
481 self.assertEquals(r, u'é')
465
482
466 def test_unicode_apply_arg(self):
483 def test_unicode_apply_arg(self):
467 """test passing unicode arguments to apply"""
484 """test passing unicode arguments to apply"""
468 v = self.client[-1]
485 v = self.client[-1]
469
486
470 @interactive
487 @interactive
471 def check_unicode(a, check):
488 def check_unicode(a, check):
472 assert isinstance(a, unicode), "%r is not unicode"%a
489 assert isinstance(a, unicode), "%r is not unicode"%a
473 assert isinstance(check, bytes), "%r is not bytes"%check
490 assert isinstance(check, bytes), "%r is not bytes"%check
474 assert a.encode('utf8') == check, "%s != %s"%(a,check)
491 assert a.encode('utf8') == check, "%s != %s"%(a,check)
475
492
476 for s in [ u'é', u'ßø®∫',u'asdf' ]:
493 for s in [ u'é', u'ßø®∫',u'asdf' ]:
477 try:
494 try:
478 v.apply_sync(check_unicode, s, s.encode('utf8'))
495 v.apply_sync(check_unicode, s, s.encode('utf8'))
479 except error.RemoteError as e:
496 except error.RemoteError as e:
480 if e.ename == 'AssertionError':
497 if e.ename == 'AssertionError':
481 self.fail(e.evalue)
498 self.fail(e.evalue)
482 else:
499 else:
483 raise e
500 raise e
484
501
485 def test_map_reference(self):
502 def test_map_reference(self):
486 """view.map(<Reference>, *seqs) should work"""
503 """view.map(<Reference>, *seqs) should work"""
487 v = self.client[:]
504 v = self.client[:]
488 v.scatter('n', self.client.ids, flatten=True)
505 v.scatter('n', self.client.ids, flatten=True)
489 v.execute("f = lambda x,y: x*y")
506 v.execute("f = lambda x,y: x*y")
490 rf = pmod.Reference('f')
507 rf = pmod.Reference('f')
491 nlist = list(range(10))
508 nlist = list(range(10))
492 mlist = nlist[::-1]
509 mlist = nlist[::-1]
493 expected = [ m*n for m,n in zip(mlist, nlist) ]
510 expected = [ m*n for m,n in zip(mlist, nlist) ]
494 result = v.map_sync(rf, mlist, nlist)
511 result = v.map_sync(rf, mlist, nlist)
495 self.assertEquals(result, expected)
512 self.assertEquals(result, expected)
496
513
497 def test_apply_reference(self):
514 def test_apply_reference(self):
498 """view.apply(<Reference>, *args) should work"""
515 """view.apply(<Reference>, *args) should work"""
499 v = self.client[:]
516 v = self.client[:]
500 v.scatter('n', self.client.ids, flatten=True)
517 v.scatter('n', self.client.ids, flatten=True)
501 v.execute("f = lambda x: n*x")
518 v.execute("f = lambda x: n*x")
502 rf = pmod.Reference('f')
519 rf = pmod.Reference('f')
503 result = v.apply_sync(rf, 5)
520 result = v.apply_sync(rf, 5)
504 expected = [ 5*id for id in self.client.ids ]
521 expected = [ 5*id for id in self.client.ids ]
505 self.assertEquals(result, expected)
522 self.assertEquals(result, expected)
506
523
507 def test_eval_reference(self):
524 def test_eval_reference(self):
508 v = self.client[self.client.ids[0]]
525 v = self.client[self.client.ids[0]]
509 v['g'] = range(5)
526 v['g'] = range(5)
510 rg = pmod.Reference('g[0]')
527 rg = pmod.Reference('g[0]')
511 echo = lambda x:x
528 echo = lambda x:x
512 self.assertEquals(v.apply_sync(echo, rg), 0)
529 self.assertEquals(v.apply_sync(echo, rg), 0)
513
530
514 def test_reference_nameerror(self):
531 def test_reference_nameerror(self):
515 v = self.client[self.client.ids[0]]
532 v = self.client[self.client.ids[0]]
516 r = pmod.Reference('elvis_has_left')
533 r = pmod.Reference('elvis_has_left')
517 echo = lambda x:x
534 echo = lambda x:x
518 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
535 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
519
536
520
537
@@ -1,480 +1,480
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 # IPython imports
42 # IPython imports
43 from IPython.config.application import Application
43 from IPython.config.application import Application
44 from IPython.utils import py3compat
44 from IPython.utils import py3compat
45 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
46 from IPython.utils.newserialized import serialize, unserialize
46 from IPython.utils.newserialized import serialize, unserialize
47 from IPython.zmq.log import EnginePUBHandler
47 from IPython.zmq.log import EnginePUBHandler
48
48
49 if py3compat.PY3:
49 if py3compat.PY3:
50 buffer = memoryview
50 buffer = memoryview
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # Classes
53 # Classes
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55
55
class Namespace(dict):
    """Subclass of dict for attribute access to keys.

    ``ns.x`` reads ``ns['x']`` (raising NameError when missing), and
    ``ns.x = v`` writes ``ns['x'] = v``, refusing names that would shadow
    dict methods.
    """

    def __getattr__(self, key):
        """getattr aliased to getitem

        Raises
        ------
        NameError
            If `key` is not present in the dict.
        """
        # direct containment test: O(1), and unlike the old
        # `key in self.iterkeys()` it also works on Python 3
        if key in self:
            return self[key]
        else:
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict

        Raises
        ------
        KeyError
            If `key` would shadow an attribute of dict (e.g. 'keys').
        """
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
71
71
72
72
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Lookups fall back to the reverse (value -> key) mapping, so both
    ``d[key]`` and ``d[value]`` resolve.  Keys and values must therefore
    occupy disjoint spaces.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        # build the value->key mapping for any initial content;
        # items() (not the Python-2-only iteritems()) works on Python 2 and 3
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        """forward lookup first, then the reverse mapping"""
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        # a key that is already used as a value would make lookups ambiguous
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        """remove `key` (forward direction only) and its reverse entry"""
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        """like dict.get, but consults both directions"""
        try:
            return self[key]
        except KeyError:
            return default
104
104
105 #-----------------------------------------------------------------------------
105 #-----------------------------------------------------------------------------
106 # Functions
106 # Functions
107 #-----------------------------------------------------------------------------
107 #-----------------------------------------------------------------------------
108
108
def asbytes(s):
    """ensure that an object is ascii bytes

    Text (unicode on Python 2, str on Python 3) is encoded as ASCII;
    anything else is returned unchanged.  The bare `unicode` builtin of
    the original raised NameError on Python 3.
    """
    try:
        text_type = unicode   # Python 2
    except NameError:
        text_type = str       # Python 3
    if isinstance(s, text_type):
        s = s.encode('ascii')
    return s
114
114
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    # must have a scheme separator, and the scheme must be one zmq knows
    if '://' not in url:
        return False
    scheme = url.split('://', 1)[0]
    return scheme.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
123
123
def validate_url(url):
    """validate a url for zeromq

    Returns True for a valid url; raises TypeError for a non-string and
    AssertionError for a malformed url.  Only tcp urls are fully
    validated (host pattern and integer port); other schemes pass.
    """
    try:
        string_types = basestring   # Python 2
    except NameError:
        string_types = str          # Python 3
    if not isinstance(url, string_types):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            # bug fix: the message formatted `port`, which is unbound when
            # int() fails, so callers got a NameError instead of this error
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
155
155
156
156
def validate_url_container(container):
    """validate a potentially nested collection of urls.

    A string is validated directly; a dict has its values validated;
    any other iterable is validated element by element (recursively).
    Raises whatever validate_url raises for the first bad url.
    """
    try:
        string_types = basestring   # Python 2
    except NameError:
        string_types = str          # Python 3
    if isinstance(container, string_types):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        # values() (not the Python-2-only itervalues()) works on 2 and 3
        container = container.values()

    for element in container:
        validate_url_container(element)
167
167
168
168
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    pieces = url.split('://')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    protocol, rest = pieces
    host_port = rest.split(':')
    assert len(host_port) == 2, 'Invalid url: %r'%url
    host, port = host_port
    # port is returned as a string, exactly as it appeared in the url
    return protocol, host, port
178
178
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    # concrete addresses pass straight through
    if ip not in ('0.0.0.0', '*'):
        return ip
    try:
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
    except (socket.gaierror, IndexError):
        # couldn't identify this machine, assume localhost
        external_ips = []
    if location is None or location in external_ips or not external_ips:
        # If location is unspecified or cannot be determined, assume local
        return '127.0.0.1'
    elif location:
        return location
    return ip
194
194
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc.
        return url
    concrete_ip = disambiguate_ip_address(ip, location)
    return "%s://%s:%s" % (proto, concrete_ip, port)
209
209
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []

    def _extract(s):
        # large or binary payloads travel as raw buffers; the pickled
        # wrapper then carries data=None as a placeholder
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return s

    if isinstance(obj, (list, tuple)):
        slist = [_extract(serialize(c)) for c in canSequence(obj)]
        return pickle.dumps(slist, -1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted key order so the buffer order is deterministic and
        # matches unserialize_object's traversal
        for k in sorted(obj.iterkeys()):
            sobj[k] = _extract(serialize(can(obj[k])))
        return pickle.dumps(sobj, -1), databuffers
    else:
        return pickle.dumps(_extract(serialize(can(obj))), -1), databuffers
252
252
253
253
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers."""
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))

    def _restore(s):
        # wrappers shipped with data=None take the next raw buffer, in order
        if s.data is None:
            s.data = bufs.pop(0)
        return s

    if isinstance(sobj, (list, tuple)):
        restored = [_restore(s) for s in sobj]
        return uncanSequence(map(unserialize, restored)), bufs
    elif isinstance(sobj, dict):
        newobj = {}
        # same sorted-key order as serialize_object used when packing
        for k in sorted(sobj.iterkeys()):
            newobj[k] = uncan(unserialize(_restore(sobj[k])))
        return newobj, bufs
    else:
        return uncan(unserialize(_restore(sobj))), bufs
275
275
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    # wire layout: [pickled f, args metadata, kwargs metadata, raw data buffers...]
    cf = pickle.dumps(can(f), -1)
    sargs, arg_bufs = serialize_object(args, threshold)
    skwargs, kwarg_bufs = serialize_object(kwargs, threshold)
    return [cf, sargs, skwargs] + arg_bufs + kwarg_bufs
290
290
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # the first three frames are zmq Messages; unpickling needs bytes
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)

    def _refill(s):
        # wrappers shipped with data=None consume the next raw buffer
        if s.data is not None:
            return
        m = bufs.pop(0)
        if s.getTypeDescriptor() in ('buffer', 'ndarray'):
            # always use a buffer, until memoryviews get sorted out
            s.data = buffer(m)
        elif copy:
            s.data = m
        else:
            s.data = m.bytes

    for sa in sargs:
        _refill(sa)
    args = uncanSequence(map(unserialize, sargs), g)

    kwargs = {}
    # same sorted-key order as the packing side
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        _refill(sa)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
344
344
345 #--------------------------------------------------------------------------
345 #--------------------------------------------------------------------------
346 # helpers for implementing old MEC API via view.apply
346 # helpers for implementing old MEC API via view.apply
347 #--------------------------------------------------------------------------
347 #--------------------------------------------------------------------------
348
348
def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """
    # tagging the function as __main__ makes canning/pickling treat it
    # like a user-typed function bound to the interactive namespace
    setattr(f, '__module__', '__main__')
    return f
356
356
@interactive
def _push(**ns):
    """helper method for implementing `client.push` via `client.apply`"""
    user_ns = globals()
    for name, value in ns.items():
        user_ns[name] = value
361
361
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`"""
    user_ns = globals()
    if not isinstance(keys, (list, tuple, set)):
        # single name: return its value directly
        if keys not in user_ns:
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)
    # collection of names: validate all, then return their values in order
    for key in keys:
        if key not in user_ns:
            raise NameError("name '%s' is not defined"%key)
    return map(user_ns.get, keys)
375
375
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    # exec(code, globals()) is the tuple form of the Python 2 exec
    # statement, with identical behavior, and is also valid Python 3
    exec(code, globals())
380
380
381 #--------------------------------------------------------------------------
381 #--------------------------------------------------------------------------
382 # extra process management utilities
382 # extra process management utilities
383 #--------------------------------------------------------------------------
383 #--------------------------------------------------------------------------
384
384
# ports already handed out by select_random_ports in this process
_random_ports = set()

def select_random_ports(n):
    """Selects and return n random ports that are available.

    Ports are OS-assigned (bind to port 0); each socket stays open until
    all n are chosen so the OS cannot assign a duplicate within the batch.
    Previously handed-out ports are skipped via the module-level
    _random_ports set.
    """
    ports = []
    # `range`, not the Python-2-only `xrange`, so this also runs on Python 3
    for i in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        while sock.getsockname()[1] in _random_ports:
            # already handed this port out earlier; draw another
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        ports.append(sock)
    # now close the sockets and return the port numbers
    for i, sock in enumerate(ports):
        port = sock.getsockname()[1]
        sock.close()
        ports[i] = port
        _random_ports.add(port)
    return ports
404
404
def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup."""
    def _terminate_children(sig, frame):
        log = Application.instance().log
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()

        # exit status 0 for Ctrl-C, nonzero for TERM/ABRT
        sys.exit(sig != SIGINT)
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, _terminate_children)
417
417
def generate_exec_key(keyfile):
    """Write a fresh random execution key (a uuid4 string) to `keyfile`.

    The file is restricted to user-only read/write; the chmod has no
    effect on Windows.
    """
    import uuid
    key = str(uuid.uuid4())
    with open(keyfile, 'w') as f:
        f.write(key + '\n')
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
427
427
428
428
def integer_loglevel(loglevel):
    """Coerce a loglevel given as an int, digit-string, or level name
    (e.g. 'DEBUG') to its integer value."""
    try:
        return int(loglevel)
    except ValueError:
        pass
    # not numeric: resolve a level name via the logging module
    if isinstance(loglevel, str):
        loglevel = getattr(logging, loglevel)
    return loglevel
436
436
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler to logger `logname`, publishing records
    to `iface` under topic `root`.  A no-op if a PUBHandler is already
    attached."""
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(loglevel)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(loglevel)
450
450
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for `engine` to the root logger,
    publishing records to `iface`.  Returns the root logger, or None if
    a PUBHandler was already attached."""
    logger = logging.getLogger()
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
464
464
def local_logger(logname, loglevel=logging.DEBUG):
    """Configure logger `logname` with a timestamped StreamHandler at
    `loglevel` and return it.

    Bug fix: when a StreamHandler was already attached, the early exit
    returned None instead of the logger, so repeat callers could not use
    the return value; it now returns the existing logger.
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        # don't add a second StreamHandler, but still hand back the logger
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
480
480
General Comments 0
You need to be logged in to leave comments. Login now