##// END OF EJS Templates
add DirectView.importer contextmanager, demote targets to mutable flag...
MinRK -
Show More
@@ -10,14 +10,15 b''
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 # from .asyncresult import *
14 # from .client import Client
15 # from .dependency import *
16 # from .remotefunction import *
17 # from .view import *
18
19 import zmq
13 import zmq
20
14
21 if zmq.__version__ < '2.1.3':
15 if zmq.__version__ < '2.1.3':
22 raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)
16 raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__)
23
17
18 from .asyncresult import *
19 from .client import Client
20 from .dependency import *
21 from .remotefunction import *
22 from .view import *
23
24
@@ -245,8 +245,6 b' class Client(HasTraits):'
245 _mux_socket=Instance('zmq.Socket')
245 _mux_socket=Instance('zmq.Socket')
246 _task_socket=Instance('zmq.Socket')
246 _task_socket=Instance('zmq.Socket')
247 _task_scheme=Str()
247 _task_scheme=Str()
248 _balanced_views=Dict()
249 _direct_views=Dict()
250 _closed = False
248 _closed = False
251 _ignored_control_replies=Int(0)
249 _ignored_control_replies=Int(0)
252 _ignored_hub_replies=Int(0)
250 _ignored_hub_replies=Int(0)
@@ -389,7 +387,20 b' class Client(HasTraits):'
389 else:
387 else:
390 raise TypeError("%r not valid str target, must be 'all'"%(targets))
388 raise TypeError("%r not valid str target, must be 'all'"%(targets))
391 elif isinstance(targets, int):
389 elif isinstance(targets, int):
390 if targets < 0:
391 targets = self.ids[targets]
392 if targets not in self.ids:
393 raise IndexError("No such engine: %i"%targets)
392 targets = [targets]
394 targets = [targets]
395
396 if isinstance(targets, slice):
397 indices = range(len(self._ids))[targets]
398 ids = self.ids
399 targets = [ ids[i] for i in indices ]
400
401 if not isinstance(targets, (tuple, list, xrange)):
402 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
403
393 return [self._engines[t] for t in targets], list(targets)
404 return [self._engines[t] for t in targets], list(targets)
394
405
395 def _connect(self, sshserver, ssh_kwargs, timeout):
406 def _connect(self, sshserver, ssh_kwargs, timeout):
@@ -688,7 +699,7 b' class Client(HasTraits):'
688 if not isinstance(key, (int, slice, tuple, list, xrange)):
699 if not isinstance(key, (int, slice, tuple, list, xrange)):
689 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
700 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
690 else:
701 else:
691 return self._get_view(key, balanced=False)
702 return self.direct_view(key)
692
703
693 #--------------------------------------------------------------------------
704 #--------------------------------------------------------------------------
694 # Begin public methods
705 # Begin public methods
@@ -962,31 +973,6 b' class Client(HasTraits):'
962 # construct a View object
973 # construct a View object
963 #--------------------------------------------------------------------------
974 #--------------------------------------------------------------------------
964
975
965 def _cache_view(self, targets, balanced):
966 """save views, so subsequent requests don't create new objects."""
967 if balanced:
968 # validate whether we can run
969 if not self._task_socket:
970 msg = "Task farming is disabled"
971 if self._task_scheme == 'pure':
972 msg += " because the pure ZMQ scheduler cannot handle"
973 msg += " disappearing engines."
974 raise RuntimeError(msg)
975 socket = self._task_socket
976 view_class = LoadBalancedView
977 view_cache = self._balanced_views
978 else:
979 socket = self._mux_socket
980 view_class = DirectView
981 view_cache = self._direct_views
982
983 # use str, since often targets will be a list
984 key = str(targets)
985 if key not in view_cache:
986 view_cache[key] = view_class(client=self, socket=socket, targets=targets)
987
988 return view_cache[key]
989
990 def load_balanced_view(self, targets=None):
976 def load_balanced_view(self, targets=None):
991 """construct a DirectView object.
977 """construct a DirectView object.
992
978
@@ -999,7 +985,9 b' class Client(HasTraits):'
999 targets: list,slice,int,etc. [default: use all engines]
985 targets: list,slice,int,etc. [default: use all engines]
1000 The subset of engines across which to load-balance
986 The subset of engines across which to load-balance
1001 """
987 """
1002 return self._get_view(targets, balanced=True)
988 if targets is None:
989 targets = self._build_targets(targets)[1]
990 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1003
991
1004 def direct_view(self, targets='all'):
992 def direct_view(self, targets='all'):
1005 """construct a DirectView object.
993 """construct a DirectView object.
@@ -1013,49 +1001,11 b' class Client(HasTraits):'
1013 targets: list,slice,int,etc. [default: use all engines]
1001 targets: list,slice,int,etc. [default: use all engines]
1014 The engines to use for the View
1002 The engines to use for the View
1015 """
1003 """
1016 return self._get_view(targets, balanced=False)
1004 single = isinstance(targets, int)
1017
1005 targets = self._build_targets(targets)[1]
1018 def _get_view(self, targets, balanced):
1006 if single:
1019 """Method for constructing View objects.
1007 targets = targets[0]
1020
1008 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1021 If no arguments are specified, create a LoadBalancedView
1022 using all engines. If only `targets` specified, it will
1023 be a DirectView. This method is the underlying implementation
1024 of ``client.__getitem__``.
1025
1026 Parameters
1027 ----------
1028
1029 targets: list,slice,int,etc. [default: use all engines]
1030 The engines to use for the View
1031 balanced : bool [default: False if targets specified, True else]
1032 whether to build a LoadBalancedView or a DirectView
1033
1034 """
1035
1036 if targets in (None,'all'):
1037 if balanced:
1038 return self._cache_view(None,True)
1039 else:
1040 targets = slice(None)
1041
1042 if isinstance(targets, int):
1043 if targets < 0:
1044 targets = self.ids[targets]
1045 if targets not in self.ids:
1046 raise IndexError("No such engine: %i"%targets)
1047 return self._cache_view(targets, balanced)
1048
1049 if isinstance(targets, slice):
1050 indices = range(len(self.ids))[targets]
1051 ids = sorted(self._ids)
1052 targets = [ ids[i] for i in indices ]
1053
1054 if isinstance(targets, (tuple, list, xrange)):
1055 _,targets = self._build_targets(list(targets))
1056 return self._cache_view(targets, balanced)
1057 else:
1058 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1059
1009
1060 #--------------------------------------------------------------------------
1010 #--------------------------------------------------------------------------
1061 # Data movement (TO BE REMOVED)
1011 # Data movement (TO BE REMOVED)
@@ -6,6 +6,8 b''
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 from types import ModuleType
10
9 from .asyncresult import AsyncResult
11 from .asyncresult import AsyncResult
10 from .error import UnmetDependency
12 from .error import UnmetDependency
11 from .util import interactive
13 from .util import interactive
@@ -76,7 +78,7 b' def _require(*names):'
76 raise UnmetDependency(name)
78 raise UnmetDependency(name)
77 return True
79 return True
78
80
79 def require(*names):
81 def require(*mods):
80 """Simple decorator for requiring names to be importable.
82 """Simple decorator for requiring names to be importable.
81
83
82 Examples
84 Examples
@@ -87,6 +89,16 b' def require(*names):'
87 ...: import numpy
89 ...: import numpy
88 ...: return numpy.linalg.norm(a,2)
90 ...: return numpy.linalg.norm(a,2)
89 """
91 """
92 names = []
93 for mod in mods:
94 if isinstance(mod, ModuleType):
95 mod = mod.__name__
96
97 if isinstance(mod, basestring):
98 names.append(mod)
99 else:
100 raise TypeError("names must be modules or module names, not %s"%type(mod))
101
90 return depend(_require, *names)
102 return depend(_require, *names)
91
103
92 class Dependency(set):
104 class Dependency(set):
@@ -48,9 +48,9 b' from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError'
48
48
49 from .factory import LoggingFactory
49 from .factory import LoggingFactory
50
50
51 # load winhpcjob from IPython.kernel
51 # load winhpcjob only on Windows
52 try:
52 try:
53 from IPython.kernel.winhpcjob import (
53 from .winhpcjob import (
54 IPControllerTask, IPEngineTask,
54 IPControllerTask, IPEngineTask,
55 IPControllerJob, IPEngineSetJob
55 IPControllerJob, IPEngineSetJob
56 )
56 )
@@ -61,15 +61,6 b' class TestClient(ClusterTestCase):'
61 self.assertEquals(v.targets, targets[-1])
61 self.assertEquals(v.targets, targets[-1])
62 self.assertRaises(TypeError, lambda : self.client[None])
62 self.assertRaises(TypeError, lambda : self.client[None])
63
63
64 def test_view_cache(self):
65 """test that multiple view requests return the same object"""
66 v = self.client[:2]
67 v2 =self.client[:2]
68 self.assertTrue(v is v2)
69 v = self.client.load_balanced_view()
70 v2 = self.client.load_balanced_view(targets=None)
71 self.assertTrue(v is v2)
72
73 def test_targets(self):
64 def test_targets(self):
74 """test various valid targets arguments"""
65 """test various valid targets arguments"""
75 build = self.client._build_targets
66 build = self.client._build_targets
@@ -285,3 +285,17 b' class TestView(ClusterTestCase):'
285 self.assertFalse(view.block)
285 self.assertFalse(view.block)
286 self.assertTrue(view.block)
286 self.assertTrue(view.block)
287
287
288 def test_importer(self):
289 view = self.client[-1]
290 view.clear(block=True)
291 with view.importer:
292 import re
293
294 @interactive
295 def findall(pat, s):
296 # this globals() step isn't necessary in real code
297 # only to prevent a closure in the test
298 return globals()['re'].findall(pat, s)
299
300 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
301
@@ -10,13 +10,16 b''
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import imp
14 import sys
13 import warnings
15 import warnings
14 from contextlib import contextmanager
16 from contextlib import contextmanager
17 from types import ModuleType
15
18
16 import zmq
19 import zmq
17
20
18 from IPython.testing import decorators as testdec
21 from IPython.testing import decorators as testdec
19 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance
22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
20
23
21 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
22
25
@@ -94,48 +97,36 b' class View(HasTraits):'
94 abort, shutdown
97 abort, shutdown
95
98
96 """
99 """
100 # flags
97 block=Bool(False)
101 block=Bool(False)
98 track=Bool(True)
102 track=Bool(True)
103 targets = Any()
104
99 history=List()
105 history=List()
100 outstanding = Set()
106 outstanding = Set()
101 results = Dict()
107 results = Dict()
102 client = Instance('IPython.zmq.parallel.client.Client')
108 client = Instance('IPython.zmq.parallel.client.Client')
103
109
104 _socket = Instance('zmq.Socket')
110 _socket = Instance('zmq.Socket')
105 _ntargets = Int(1)
111 _flag_names = List(['targets', 'block', 'track'])
106 _flag_names = List(['block', 'track'])
107 _targets = Any()
112 _targets = Any()
108 _idents = Any()
113 _idents = Any()
109
114
110 def __init__(self, client=None, socket=None, targets=None):
115 def __init__(self, client=None, socket=None, **flags):
111 super(View, self).__init__(client=client, _socket=socket)
116 super(View, self).__init__(client=client, _socket=socket)
112 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
113 self.block = client.block
117 self.block = client.block
114
118
115 self._idents, self._targets = self.client._build_targets(targets)
119 self.set_flags(**flags)
116 if targets is None or isinstance(targets, int):
117 self._targets = targets
118 for name in self._flag_names:
119 # set flags, if they haven't been set yet
120 setattr(self, name, getattr(self, name, None))
121
120
122 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
121 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
123
122
124
123
125 def __repr__(self):
124 def __repr__(self):
126 strtargets = str(self._targets)
125 strtargets = str(self.targets)
127 if len(strtargets) > 16:
126 if len(strtargets) > 16:
128 strtargets = strtargets[:12]+'...]'
127 strtargets = strtargets[:12]+'...]'
129 return "<%s %s>"%(self.__class__.__name__, strtargets)
128 return "<%s %s>"%(self.__class__.__name__, strtargets)
130
129
131 @property
132 def targets(self):
133 return self._targets
134
135 @targets.setter
136 def targets(self, value):
137 raise AttributeError("Cannot set View `targets` after construction!")
138
139 def set_flags(self, **kwargs):
130 def set_flags(self, **kwargs):
140 """set my attribute flags by keyword.
131 """set my attribute flags by keyword.
141
132
@@ -182,9 +173,11 b' class View(HasTraits):'
182 saved_flags[f] = getattr(self, f)
173 saved_flags[f] = getattr(self, f)
183 self.set_flags(**kwargs)
174 self.set_flags(**kwargs)
184 # yield to the with-statement block
175 # yield to the with-statement block
185 yield
176 try:
186 # postflight: restore saved flags
177 yield
187 self.set_flags(**saved_flags)
178 finally:
179 # postflight: restore saved flags
180 self.set_flags(**saved_flags)
188
181
189
182
190 #----------------------------------------------------------------
183 #----------------------------------------------------------------
@@ -258,7 +251,7 b' class View(HasTraits):'
258 jobs = self.history
251 jobs = self.history
259 return self.client.wait(jobs, timeout)
252 return self.client.wait(jobs, timeout)
260
253
261 def abort(self, jobs=None, block=None):
254 def abort(self, jobs=None, targets=None, block=None):
262 """Abort jobs on my engines.
255 """Abort jobs on my engines.
263
256
264 Parameters
257 Parameters
@@ -269,16 +262,18 b' class View(HasTraits):'
269 else: abort specific msg_id(s).
262 else: abort specific msg_id(s).
270 """
263 """
271 block = block if block is not None else self.block
264 block = block if block is not None else self.block
272 return self.client.abort(jobs=jobs, targets=self._targets, block=block)
265 targets = targets if targets is not None else self.targets
266 return self.client.abort(jobs=jobs, targets=targets, block=block)
273
267
274 def queue_status(self, verbose=False):
268 def queue_status(self, targets=None, verbose=False):
275 """Fetch the Queue status of my engines"""
269 """Fetch the Queue status of my engines"""
276 return self.client.queue_status(targets=self._targets, verbose=verbose)
270 targets = targets if targets is not None else self.targets
271 return self.client.queue_status(targets=targets, verbose=verbose)
277
272
278 def purge_results(self, jobs=[], targets=[]):
273 def purge_results(self, jobs=[], targets=[]):
279 """Instruct the controller to forget specific results."""
274 """Instruct the controller to forget specific results."""
280 if targets is None or targets == 'all':
275 if targets is None or targets == 'all':
281 targets = self._targets
276 targets = self.targets
282 return self.client.purge_results(jobs=jobs, targets=targets)
277 return self.client.purge_results(jobs=jobs, targets=targets)
283
278
284 @spin_after
279 @spin_after
@@ -377,11 +372,104 b' class DirectView(View):'
377
372
378 def __init__(self, client=None, socket=None, targets=None):
373 def __init__(self, client=None, socket=None, targets=None):
379 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
374 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
375
376 @property
377 def importer(self):
378 """sync_imports(local=True) as a property.
380
379
380 See sync_imports for details.
381
382 In [10]: with v.importer:
383 ....: import numpy
384 ....:
385 importing numpy on engine(s)
386
387 """
388 return self.sync_imports(True)
389
390 @contextmanager
391 def sync_imports(self, local=True):
392 """Context Manager for performing simultaneous local and remote imports.
393
394 'import x as y' will *not* work. The 'as y' part will simply be ignored.
395
396 >>> with view.sync_imports():
397 ... from numpy import recarray
398 importing recarray from numpy on engine(s)
399
400 """
401 import __builtin__
402 local_import = __builtin__.__import__
403 modules = set()
404 results = []
405 @util.interactive
406 def remote_import(name, fromlist, level):
407 """the function to be passed to apply, that actually performs the import
408 on the engine, and loads up the user namespace.
409 """
410 import sys
411 user_ns = globals()
412 mod = __import__(name, fromlist=fromlist, level=level)
413 if fromlist:
414 for key in fromlist:
415 user_ns[key] = getattr(mod, key)
416 else:
417 user_ns[name] = sys.modules[name]
418
419 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
420 """the drop-in replacement for __import__, that optionally imports
421 locally as well.
422 """
423 # don't override nested imports
424 save_import = __builtin__.__import__
425 __builtin__.__import__ = local_import
426
427 if imp.lock_held():
428 # this is a side-effect import, don't do it remotely, or even
429 # ignore the local effects
430 return local_import(name, globals, locals, fromlist, level)
431
432 imp.acquire_lock()
433 if local:
434 mod = local_import(name, globals, locals, fromlist, level)
435 else:
436 raise NotImplementedError("remote-only imports not yet implemented")
437 imp.release_lock()
438
439 key = name+':'+','.join(fromlist or [])
440 if level == -1 and key not in modules:
441 modules.add(key)
442 if fromlist:
443 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
444 else:
445 print "importing %s on engine(s)"%name
446 results.append(self.apply_async(remote_import, name, fromlist, level))
447 # restore override
448 __builtin__.__import__ = save_import
449
450 return mod
451
452 # override __import__
453 __builtin__.__import__ = view_import
454 try:
455 # enter the block
456 yield
457 except ImportError:
458 if not local:
459 # ignore import errors if not doing local imports
460 pass
461 finally:
462 # always restore __import__
463 __builtin__.__import__ = local_import
464
465 for r in results:
466 # raise possible remote ImportErrors here
467 r.get()
468
381
469
382 @sync_results
470 @sync_results
383 @save_ids
471 @save_ids
384 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None):
472 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
385 """calls f(*args, **kwargs) on remote engines, returning the result.
473 """calls f(*args, **kwargs) on remote engines, returning the result.
386
474
387 This method sets all of `apply`'s flags via this View's attributes.
475 This method sets all of `apply`'s flags via this View's attributes.
@@ -395,6 +483,8 b' class DirectView(View):'
395
483
396 kwargs : dict [default: empty]
484 kwargs : dict [default: empty]
397
485
486 targets : target list [default: self.targets]
487 where to run
398 block : bool [default: self.block]
488 block : bool [default: self.block]
399 whether to block
489 whether to block
400 track : bool [default: self.track]
490 track : bool [default: self.track]
@@ -414,16 +504,19 b' class DirectView(View):'
414 kwargs = {} if kwargs is None else kwargs
504 kwargs = {} if kwargs is None else kwargs
415 block = self.block if block is None else block
505 block = self.block if block is None else block
416 track = self.track if track is None else track
506 track = self.track if track is None else track
507 targets = self.targets if targets is None else targets
508
509 _idents = self.client._build_targets(targets)[0]
417 msg_ids = []
510 msg_ids = []
418 trackers = []
511 trackers = []
419 for ident in self._idents:
512 for ident in _idents:
420 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
513 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
421 ident=ident)
514 ident=ident)
422 if track:
515 if track:
423 trackers.append(msg['tracker'])
516 trackers.append(msg['tracker'])
424 msg_ids.append(msg['msg_id'])
517 msg_ids.append(msg['msg_id'])
425 tracker = None if track is False else zmq.MessageTracker(*trackers)
518 tracker = None if track is False else zmq.MessageTracker(*trackers)
426 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=self._targets, tracker=tracker)
519 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
427 if block:
520 if block:
428 try:
521 try:
429 return ar.get()
522 return ar.get()
@@ -474,7 +567,7 b' class DirectView(View):'
474 pf = ParallelFunction(self, f, block=block, **kwargs)
567 pf = ParallelFunction(self, f, block=block, **kwargs)
475 return pf.map(*sequences)
568 return pf.map(*sequences)
476
569
477 def execute(self, code, block=None):
570 def execute(self, code, targets=None, block=None):
478 """Executes `code` on `targets` in blocking or nonblocking manner.
571 """Executes `code` on `targets` in blocking or nonblocking manner.
479
572
480 ``execute`` is always `bound` (affects engine namespace)
573 ``execute`` is always `bound` (affects engine namespace)
@@ -488,9 +581,9 b' class DirectView(View):'
488 whether or not to wait until done to return
581 whether or not to wait until done to return
489 default: self.block
582 default: self.block
490 """
583 """
491 return self._really_apply(util._execute, args=(code,), block=block)
584 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
492
585
493 def run(self, filename, block=None):
586 def run(self, filename, targets=None, block=None):
494 """Execute contents of `filename` on my engine(s).
587 """Execute contents of `filename` on my engine(s).
495
588
496 This simply reads the contents of the file and calls `execute`.
589 This simply reads the contents of the file and calls `execute`.
@@ -512,7 +605,7 b' class DirectView(View):'
512 # add newline in case of trailing indented whitespace
605 # add newline in case of trailing indented whitespace
513 # which will cause SyntaxError
606 # which will cause SyntaxError
514 code = f.read()+'\n'
607 code = f.read()+'\n'
515 return self.execute(code, block=block)
608 return self.execute(code, block=block, targets=targets)
516
609
517 def update(self, ns):
610 def update(self, ns):
518 """update remote namespace with dict `ns`
611 """update remote namespace with dict `ns`
@@ -521,7 +614,7 b' class DirectView(View):'
521 """
614 """
522 return self.push(ns, block=self.block, track=self.track)
615 return self.push(ns, block=self.block, track=self.track)
523
616
524 def push(self, ns, block=None, track=None):
617 def push(self, ns, targets=None, block=None, track=None):
525 """update remote namespace with dict `ns`
618 """update remote namespace with dict `ns`
526
619
527 Parameters
620 Parameters
@@ -536,10 +629,11 b' class DirectView(View):'
536
629
537 block = block if block is not None else self.block
630 block = block if block is not None else self.block
538 track = track if track is not None else self.track
631 track = track if track is not None else self.track
632 targets = targets if targets is not None else self.targets
539 # applier = self.apply_sync if block else self.apply_async
633 # applier = self.apply_sync if block else self.apply_async
540 if not isinstance(ns, dict):
634 if not isinstance(ns, dict):
541 raise TypeError("Must be a dict, not %s"%type(ns))
635 raise TypeError("Must be a dict, not %s"%type(ns))
542 return self._really_apply(util._push, (ns,),block=block, track=track)
636 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
543
637
544 def get(self, key_s):
638 def get(self, key_s):
545 """get object(s) by `key_s` from remote namespace
639 """get object(s) by `key_s` from remote namespace
@@ -549,13 +643,14 b' class DirectView(View):'
549 # block = block if block is not None else self.block
643 # block = block if block is not None else self.block
550 return self.pull(key_s, block=True)
644 return self.pull(key_s, block=True)
551
645
552 def pull(self, names, block=True):
646 def pull(self, names, targets=None, block=True):
553 """get object(s) by `name` from remote namespace
647 """get object(s) by `name` from remote namespace
554
648
555 will return one object if it is a key.
649 will return one object if it is a key.
556 can also take a list of keys, in which case it will return a list of objects.
650 can also take a list of keys, in which case it will return a list of objects.
557 """
651 """
558 block = block if block is not None else self.block
652 block = block if block is not None else self.block
653 targets = targets if targets is not None else self.targets
559 applier = self.apply_sync if block else self.apply_async
654 applier = self.apply_sync if block else self.apply_async
560 if isinstance(names, basestring):
655 if isinstance(names, basestring):
561 pass
656 pass
@@ -565,26 +660,27 b' class DirectView(View):'
565 raise TypeError("keys must be str, not type %r"%type(key))
660 raise TypeError("keys must be str, not type %r"%type(key))
566 else:
661 else:
567 raise TypeError("names must be strs, not %r"%names)
662 raise TypeError("names must be strs, not %r"%names)
568 return applier(util._pull, names)
663 return self._really_apply(util._pull, (names,), block=block, targets=targets)
569
664
570 def scatter(self, key, seq, dist='b', flatten=False, block=None, track=None):
665 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
571 """
666 """
572 Partition a Python sequence and send the partitions to a set of engines.
667 Partition a Python sequence and send the partitions to a set of engines.
573 """
668 """
574 block = block if block is not None else self.block
669 block = block if block is not None else self.block
575 track = track if track is not None else self.track
670 track = track if track is not None else self.track
576 targets = self._targets
671 targets = targets if targets is not None else self.targets
672
577 mapObject = Map.dists[dist]()
673 mapObject = Map.dists[dist]()
578 nparts = len(targets)
674 nparts = len(targets)
579 msg_ids = []
675 msg_ids = []
580 trackers = []
676 trackers = []
581 for index, engineid in enumerate(targets):
677 for index, engineid in enumerate(targets):
582 push = self.client[engineid].push
583 partition = mapObject.getPartition(seq, index, nparts)
678 partition = mapObject.getPartition(seq, index, nparts)
584 if flatten and len(partition) == 1:
679 if flatten and len(partition) == 1:
585 r = push({key: partition[0]}, block=False, track=track)
680 ns = {key: partition[0]}
586 else:
681 else:
587 r = push({key: partition},block=False, track=track)
682 ns = {key: partition}
683 r = self.push(ns, block=False, track=track, targets=engineid)
588 msg_ids.extend(r.msg_ids)
684 msg_ids.extend(r.msg_ids)
589 if track:
685 if track:
590 trackers.append(r._tracker)
686 trackers.append(r._tracker)
@@ -602,16 +698,17 b' class DirectView(View):'
602
698
603 @sync_results
699 @sync_results
604 @save_ids
700 @save_ids
605 def gather(self, key, dist='b', block=None):
701 def gather(self, key, dist='b', targets=None, block=None):
606 """
702 """
607 Gather a partitioned sequence on a set of engines as a single local seq.
703 Gather a partitioned sequence on a set of engines as a single local seq.
608 """
704 """
609 block = block if block is not None else self.block
705 block = block if block is not None else self.block
706 targets = targets if targets is not None else self.targets
610 mapObject = Map.dists[dist]()
707 mapObject = Map.dists[dist]()
611 msg_ids = []
708 msg_ids = []
612 for index, engineid in enumerate(self._targets):
709
613
710 for index, engineid in enumerate(targets):
614 msg_ids.extend(self.client[engineid].pull(key, block=False).msg_ids)
711 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
615
712
616 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
713 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
617
714
@@ -628,15 +725,17 b' class DirectView(View):'
628 def __setitem__(self,key, value):
725 def __setitem__(self,key, value):
629 self.update({key:value})
726 self.update({key:value})
630
727
631 def clear(self, block=False):
728 def clear(self, targets=None, block=False):
632 """Clear the remote namespaces on my engines."""
729 """Clear the remote namespaces on my engines."""
633 block = block if block is not None else self.block
730 block = block if block is not None else self.block
634 return self.client.clear(targets=self._targets, block=block)
731 targets = targets if targets is not None else self.targets
732 return self.client.clear(targets=targets, block=block)
635
733
636 def kill(self, block=True):
734 def kill(self, targets=None, block=True):
637 """Kill my engines."""
735 """Kill my engines."""
638 block = block if block is not None else self.block
736 block = block if block is not None else self.block
639 return self.client.kill(targets=self._targets, block=block)
737 targets = targets if targets is not None else self.targets
738 return self.client.kill(targets=targets, block=block)
640
739
641 #----------------------------------------
740 #----------------------------------------
642 # activate for %px,%autopx magics
741 # activate for %px,%autopx magics
@@ -684,15 +783,16 b' class LoadBalancedView(View):'
684
783
685 """
784 """
686
785
687 _flag_names = ['block', 'track', 'follow', 'after', 'timeout']
786 follow=Any()
787 after=Any()
788 timeout=CFloat()
688
789
689 def __init__(self, client=None, socket=None, targets=None):
790 _task_scheme = Any()
690 super(LoadBalancedView, self).__init__(client=client, socket=socket, targets=targets)
791 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])
691 self._ntargets = 1
792
793 def __init__(self, client=None, socket=None, **flags):
794 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
692 self._task_scheme=client._task_scheme
795 self._task_scheme=client._task_scheme
693 if targets is None:
694 self._targets = None
695 self._idents=[]
696
796
697 def _validate_dependency(self, dep):
797 def _validate_dependency(self, dep):
698 """validate a dependency.
798 """validate a dependency.
@@ -786,7 +886,8 b' class LoadBalancedView(View):'
786 @sync_results
886 @sync_results
787 @save_ids
887 @save_ids
788 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
888 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
789 after=None, follow=None, timeout=None):
889 after=None, follow=None, timeout=None,
890 targets=None):
790 """calls f(*args, **kwargs) on a remote engine, returning the result.
891 """calls f(*args, **kwargs) on a remote engine, returning the result.
791
892
792 This method temporarily sets all of `apply`'s flags for a single call.
893 This method temporarily sets all of `apply`'s flags for a single call.
@@ -844,9 +945,16 b' class LoadBalancedView(View):'
844 after = self.after if after is None else after
945 after = self.after if after is None else after
845 follow = self.follow if follow is None else follow
946 follow = self.follow if follow is None else follow
846 timeout = self.timeout if timeout is None else timeout
947 timeout = self.timeout if timeout is None else timeout
948 targets = self.targets if targets is None else targets
949
950 if targets is None:
951 idents = []
952 else:
953 idents = self.client._build_targets(targets)[0]
954
847 after = self._render_dependency(after)
955 after = self._render_dependency(after)
848 follow = self._render_dependency(follow)
956 follow = self._render_dependency(follow)
849 subheader = dict(after=after, follow=follow, timeout=timeout, targets=self._idents)
957 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
850
958
851 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
959 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
852 subheader=subheader)
960 subheader=subheader)
@@ -916,5 +1024,5 b' class LoadBalancedView(View):'
916
1024
917 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1025 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
918 return pf.map(*sequences)
1026 return pf.map(*sequences)
919
1027
920 __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file
1028 __all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now