Show More
@@ -10,14 +10,15 b'' | |||
|
10 | 10 | # Imports |
|
11 | 11 | #----------------------------------------------------------------------------- |
|
12 | 12 | |
|
13 | # from .asyncresult import * | |
|
14 | # from .client import Client | |
|
15 | # from .dependency import * | |
|
16 | # from .remotefunction import * | |
|
17 | # from .view import * | |
|
18 | ||
|
19 | 13 | import zmq |
|
20 | 14 | |
|
21 | 15 | if zmq.__version__ < '2.1.3': |
|
22 | 16 | raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__) |
|
23 | 17 | |
|
18 | from .asyncresult import * | |
|
19 | from .client import Client | |
|
20 | from .dependency import * | |
|
21 | from .remotefunction import * | |
|
22 | from .view import * | |
|
23 | ||
|
24 |
@@ -245,8 +245,6 b' class Client(HasTraits):' | |||
|
245 | 245 | _mux_socket=Instance('zmq.Socket') |
|
246 | 246 | _task_socket=Instance('zmq.Socket') |
|
247 | 247 | _task_scheme=Str() |
|
248 | _balanced_views=Dict() | |
|
249 | _direct_views=Dict() | |
|
250 | 248 | _closed = False |
|
251 | 249 | _ignored_control_replies=Int(0) |
|
252 | 250 | _ignored_hub_replies=Int(0) |
@@ -389,7 +387,20 b' class Client(HasTraits):' | |||
|
389 | 387 | else: |
|
390 | 388 | raise TypeError("%r not valid str target, must be 'all'"%(targets)) |
|
391 | 389 | elif isinstance(targets, int): |
|
390 | if targets < 0: | |
|
391 | targets = self.ids[targets] | |
|
392 | if targets not in self.ids: | |
|
393 | raise IndexError("No such engine: %i"%targets) | |
|
392 | 394 | targets = [targets] |
|
395 | ||
|
396 | if isinstance(targets, slice): | |
|
397 | indices = range(len(self._ids))[targets] | |
|
398 | ids = self.ids | |
|
399 | targets = [ ids[i] for i in indices ] | |
|
400 | ||
|
401 | if not isinstance(targets, (tuple, list, xrange)): | |
|
402 | raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) | |
|
403 | ||
|
393 | 404 | return [self._engines[t] for t in targets], list(targets) |
|
394 | 405 | |
|
395 | 406 | def _connect(self, sshserver, ssh_kwargs, timeout): |
@@ -688,7 +699,7 b' class Client(HasTraits):' | |||
|
688 | 699 | if not isinstance(key, (int, slice, tuple, list, xrange)): |
|
689 | 700 | raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key))) |
|
690 | 701 | else: |
|
691 |
return self. |
|
|
702 | return self.direct_view(key) | |
|
692 | 703 | |
|
693 | 704 | #-------------------------------------------------------------------------- |
|
694 | 705 | # Begin public methods |
@@ -962,31 +973,6 b' class Client(HasTraits):' | |||
|
962 | 973 | # construct a View object |
|
963 | 974 | #-------------------------------------------------------------------------- |
|
964 | 975 | |
|
965 | def _cache_view(self, targets, balanced): | |
|
966 | """save views, so subsequent requests don't create new objects.""" | |
|
967 | if balanced: | |
|
968 | # validate whether we can run | |
|
969 | if not self._task_socket: | |
|
970 | msg = "Task farming is disabled" | |
|
971 | if self._task_scheme == 'pure': | |
|
972 | msg += " because the pure ZMQ scheduler cannot handle" | |
|
973 | msg += " disappearing engines." | |
|
974 | raise RuntimeError(msg) | |
|
975 | socket = self._task_socket | |
|
976 | view_class = LoadBalancedView | |
|
977 | view_cache = self._balanced_views | |
|
978 | else: | |
|
979 | socket = self._mux_socket | |
|
980 | view_class = DirectView | |
|
981 | view_cache = self._direct_views | |
|
982 | ||
|
983 | # use str, since often targets will be a list | |
|
984 | key = str(targets) | |
|
985 | if key not in view_cache: | |
|
986 | view_cache[key] = view_class(client=self, socket=socket, targets=targets) | |
|
987 | ||
|
988 | return view_cache[key] | |
|
989 | ||
|
990 | 976 | def load_balanced_view(self, targets=None): |
|
991 | 977 | """construct a DirectView object. |
|
992 | 978 | |
@@ -999,7 +985,9 b' class Client(HasTraits):' | |||
|
999 | 985 | targets: list,slice,int,etc. [default: use all engines] |
|
1000 | 986 | The subset of engines across which to load-balance |
|
1001 | 987 | """ |
|
1002 | return self._get_view(targets, balanced=True) | |
|
988 | if targets is None: | |
|
989 | targets = self._build_targets(targets)[1] | |
|
990 | return LoadBalancedView(client=self, socket=self._task_socket, targets=targets) | |
|
1003 | 991 | |
|
1004 | 992 | def direct_view(self, targets='all'): |
|
1005 | 993 | """construct a DirectView object. |
@@ -1013,49 +1001,11 b' class Client(HasTraits):' | |||
|
1013 | 1001 | targets: list,slice,int,etc. [default: use all engines] |
|
1014 | 1002 | The engines to use for the View |
|
1015 | 1003 | """ |
|
1016 | return self._get_view(targets, balanced=False) | |
|
1017 | ||
|
1018 | def _get_view(self, targets, balanced): | |
|
1019 | """Method for constructing View objects. | |
|
1020 | ||
|
1021 | If no arguments are specified, create a LoadBalancedView | |
|
1022 | using all engines. If only `targets` specified, it will | |
|
1023 | be a DirectView. This method is the underlying implementation | |
|
1024 | of ``client.__getitem__``. | |
|
1025 | ||
|
1026 | Parameters | |
|
1027 | ---------- | |
|
1028 | ||
|
1029 | targets: list,slice,int,etc. [default: use all engines] | |
|
1030 | The engines to use for the View | |
|
1031 | balanced : bool [default: False if targets specified, True else] | |
|
1032 | whether to build a LoadBalancedView or a DirectView | |
|
1033 | ||
|
1034 | """ | |
|
1035 | ||
|
1036 | if targets in (None,'all'): | |
|
1037 | if balanced: | |
|
1038 | return self._cache_view(None,True) | |
|
1039 | else: | |
|
1040 | targets = slice(None) | |
|
1041 | ||
|
1042 | if isinstance(targets, int): | |
|
1043 | if targets < 0: | |
|
1044 | targets = self.ids[targets] | |
|
1045 | if targets not in self.ids: | |
|
1046 | raise IndexError("No such engine: %i"%targets) | |
|
1047 | return self._cache_view(targets, balanced) | |
|
1048 | ||
|
1049 | if isinstance(targets, slice): | |
|
1050 | indices = range(len(self.ids))[targets] | |
|
1051 | ids = sorted(self._ids) | |
|
1052 | targets = [ ids[i] for i in indices ] | |
|
1053 | ||
|
1054 | if isinstance(targets, (tuple, list, xrange)): | |
|
1055 | _,targets = self._build_targets(list(targets)) | |
|
1056 | return self._cache_view(targets, balanced) | |
|
1057 | else: | |
|
1058 | raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) | |
|
1004 | single = isinstance(targets, int) | |
|
1005 | targets = self._build_targets(targets)[1] | |
|
1006 | if single: | |
|
1007 | targets = targets[0] | |
|
1008 | return DirectView(client=self, socket=self._mux_socket, targets=targets) | |
|
1059 | 1009 | |
|
1060 | 1010 | #-------------------------------------------------------------------------- |
|
1061 | 1011 | # Data movement (TO BE REMOVED) |
@@ -6,6 +6,8 b'' | |||
|
6 | 6 | # the file COPYING, distributed as part of this software. |
|
7 | 7 | #----------------------------------------------------------------------------- |
|
8 | 8 | |
|
9 | from types import ModuleType | |
|
10 | ||
|
9 | 11 | from .asyncresult import AsyncResult |
|
10 | 12 | from .error import UnmetDependency |
|
11 | 13 | from .util import interactive |
@@ -76,7 +78,7 b' def _require(*names):' | |||
|
76 | 78 | raise UnmetDependency(name) |
|
77 | 79 | return True |
|
78 | 80 | |
|
79 |
def require(* |
|
|
81 | def require(*mods): | |
|
80 | 82 | """Simple decorator for requiring names to be importable. |
|
81 | 83 | |
|
82 | 84 | Examples |
@@ -87,6 +89,16 b' def require(*names):' | |||
|
87 | 89 | ...: import numpy |
|
88 | 90 | ...: return numpy.linalg.norm(a,2) |
|
89 | 91 | """ |
|
92 | names = [] | |
|
93 | for mod in mods: | |
|
94 | if isinstance(mod, ModuleType): | |
|
95 | mod = mod.__name__ | |
|
96 | ||
|
97 | if isinstance(mod, basestring): | |
|
98 | names.append(mod) | |
|
99 | else: | |
|
100 | raise TypeError("names must be modules or module names, not %s"%type(mod)) | |
|
101 | ||
|
90 | 102 | return depend(_require, *names) |
|
91 | 103 | |
|
92 | 104 | class Dependency(set): |
@@ -48,9 +48,9 b' from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError' | |||
|
48 | 48 | |
|
49 | 49 | from .factory import LoggingFactory |
|
50 | 50 | |
|
51 | # load winhpcjob from IPython.kernel | |
|
51 | # load winhpcjob only on Windows | |
|
52 | 52 | try: |
|
53 |
from |
|
|
53 | from .winhpcjob import ( | |
|
54 | 54 | IPControllerTask, IPEngineTask, |
|
55 | 55 | IPControllerJob, IPEngineSetJob |
|
56 | 56 | ) |
@@ -61,15 +61,6 b' class TestClient(ClusterTestCase):' | |||
|
61 | 61 | self.assertEquals(v.targets, targets[-1]) |
|
62 | 62 | self.assertRaises(TypeError, lambda : self.client[None]) |
|
63 | 63 | |
|
64 | def test_view_cache(self): | |
|
65 | """test that multiple view requests return the same object""" | |
|
66 | v = self.client[:2] | |
|
67 | v2 =self.client[:2] | |
|
68 | self.assertTrue(v is v2) | |
|
69 | v = self.client.load_balanced_view() | |
|
70 | v2 = self.client.load_balanced_view(targets=None) | |
|
71 | self.assertTrue(v is v2) | |
|
72 | ||
|
73 | 64 | def test_targets(self): |
|
74 | 65 | """test various valid targets arguments""" |
|
75 | 66 | build = self.client._build_targets |
@@ -285,3 +285,17 b' class TestView(ClusterTestCase):' | |||
|
285 | 285 | self.assertFalse(view.block) |
|
286 | 286 | self.assertTrue(view.block) |
|
287 | 287 | |
|
288 | def test_importer(self): | |
|
289 | view = self.client[-1] | |
|
290 | view.clear(block=True) | |
|
291 | with view.importer: | |
|
292 | import re | |
|
293 | ||
|
294 | @interactive | |
|
295 | def findall(pat, s): | |
|
296 | # this globals() step isn't necessary in real code | |
|
297 | # only to prevent a closure in the test | |
|
298 | return globals()['re'].findall(pat, s) | |
|
299 | ||
|
300 | self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split()) | |
|
301 |
@@ -10,13 +10,16 b'' | |||
|
10 | 10 | # Imports |
|
11 | 11 | #----------------------------------------------------------------------------- |
|
12 | 12 | |
|
13 | import imp | |
|
14 | import sys | |
|
13 | 15 | import warnings |
|
14 | 16 | from contextlib import contextmanager |
|
17 | from types import ModuleType | |
|
15 | 18 | |
|
16 | 19 | import zmq |
|
17 | 20 | |
|
18 | 21 | from IPython.testing import decorators as testdec |
|
19 | from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance | |
|
22 | from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat | |
|
20 | 23 | |
|
21 | 24 | from IPython.external.decorator import decorator |
|
22 | 25 | |
@@ -94,48 +97,36 b' class View(HasTraits):' | |||
|
94 | 97 | abort, shutdown |
|
95 | 98 | |
|
96 | 99 | """ |
|
100 | # flags | |
|
97 | 101 | block=Bool(False) |
|
98 | 102 | track=Bool(True) |
|
103 | targets = Any() | |
|
104 | ||
|
99 | 105 | history=List() |
|
100 | 106 | outstanding = Set() |
|
101 | 107 | results = Dict() |
|
102 | 108 | client = Instance('IPython.zmq.parallel.client.Client') |
|
103 | 109 | |
|
104 | 110 | _socket = Instance('zmq.Socket') |
|
105 | _ntargets = Int(1) | |
|
106 | _flag_names = List(['block', 'track']) | |
|
111 | _flag_names = List(['targets', 'block', 'track']) | |
|
107 | 112 | _targets = Any() |
|
108 | 113 | _idents = Any() |
|
109 | 114 | |
|
110 |
def __init__(self, client=None, socket=None, |
|
|
115 | def __init__(self, client=None, socket=None, **flags): | |
|
111 | 116 | super(View, self).__init__(client=client, _socket=socket) |
|
112 | self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets) | |
|
113 | 117 | self.block = client.block |
|
114 | 118 | |
|
115 | self._idents, self._targets = self.client._build_targets(targets) | |
|
116 | if targets is None or isinstance(targets, int): | |
|
117 | self._targets = targets | |
|
118 | for name in self._flag_names: | |
|
119 | # set flags, if they haven't been set yet | |
|
120 | setattr(self, name, getattr(self, name, None)) | |
|
119 | self.set_flags(**flags) | |
|
121 | 120 | |
|
122 | 121 | assert not self.__class__ is View, "Don't use base View objects, use subclasses" |
|
123 | 122 | |
|
124 | 123 | |
|
125 | 124 | def __repr__(self): |
|
126 |
strtargets = str(self. |
|
|
125 | strtargets = str(self.targets) | |
|
127 | 126 | if len(strtargets) > 16: |
|
128 | 127 | strtargets = strtargets[:12]+'...]' |
|
129 | 128 | return "<%s %s>"%(self.__class__.__name__, strtargets) |
|
130 | 129 | |
|
131 | @property | |
|
132 | def targets(self): | |
|
133 | return self._targets | |
|
134 | ||
|
135 | @targets.setter | |
|
136 | def targets(self, value): | |
|
137 | raise AttributeError("Cannot set View `targets` after construction!") | |
|
138 | ||
|
139 | 130 | def set_flags(self, **kwargs): |
|
140 | 131 | """set my attribute flags by keyword. |
|
141 | 132 | |
@@ -182,9 +173,11 b' class View(HasTraits):' | |||
|
182 | 173 | saved_flags[f] = getattr(self, f) |
|
183 | 174 | self.set_flags(**kwargs) |
|
184 | 175 | # yield to the with-statement block |
|
185 |
|
|
|
186 | # postflight: restore saved flags | |
|
187 | self.set_flags(**saved_flags) | |
|
176 | try: | |
|
177 | yield | |
|
178 | finally: | |
|
179 | # postflight: restore saved flags | |
|
180 | self.set_flags(**saved_flags) | |
|
188 | 181 | |
|
189 | 182 | |
|
190 | 183 | #---------------------------------------------------------------- |
@@ -258,7 +251,7 b' class View(HasTraits):' | |||
|
258 | 251 | jobs = self.history |
|
259 | 252 | return self.client.wait(jobs, timeout) |
|
260 | 253 | |
|
261 | def abort(self, jobs=None, block=None): | |
|
254 | def abort(self, jobs=None, targets=None, block=None): | |
|
262 | 255 | """Abort jobs on my engines. |
|
263 | 256 | |
|
264 | 257 | Parameters |
@@ -269,16 +262,18 b' class View(HasTraits):' | |||
|
269 | 262 | else: abort specific msg_id(s). |
|
270 | 263 | """ |
|
271 | 264 | block = block if block is not None else self.block |
|
272 | return self.client.abort(jobs=jobs, targets=self._targets, block=block) | |
|
265 | targets = targets if targets is not None else self.targets | |
|
266 | return self.client.abort(jobs=jobs, targets=targets, block=block) | |
|
273 | 267 | |
|
274 | def queue_status(self, verbose=False): | |
|
268 | def queue_status(self, targets=None, verbose=False): | |
|
275 | 269 | """Fetch the Queue status of my engines""" |
|
276 | return self.client.queue_status(targets=self._targets, verbose=verbose) | |
|
270 | targets = targets if targets is not None else self.targets | |
|
271 | return self.client.queue_status(targets=targets, verbose=verbose) | |
|
277 | 272 | |
|
278 | 273 | def purge_results(self, jobs=[], targets=[]): |
|
279 | 274 | """Instruct the controller to forget specific results.""" |
|
280 | 275 | if targets is None or targets == 'all': |
|
281 |
targets = self. |
|
|
276 | targets = self.targets | |
|
282 | 277 | return self.client.purge_results(jobs=jobs, targets=targets) |
|
283 | 278 | |
|
284 | 279 | @spin_after |
@@ -377,11 +372,104 b' class DirectView(View):' | |||
|
377 | 372 | |
|
378 | 373 | def __init__(self, client=None, socket=None, targets=None): |
|
379 | 374 | super(DirectView, self).__init__(client=client, socket=socket, targets=targets) |
|
375 | ||
|
376 | @property | |
|
377 | def importer(self): | |
|
378 | """sync_imports(local=True) as a property. | |
|
380 | 379 |
|
|
380 | See sync_imports for details. | |
|
381 | ||
|
382 | In [10]: with v.importer: | |
|
383 | ....: import numpy | |
|
384 | ....: | |
|
385 | importing numpy on engine(s) | |
|
386 | ||
|
387 | """ | |
|
388 | return self.sync_imports(True) | |
|
389 | ||
|
390 | @contextmanager | |
|
391 | def sync_imports(self, local=True): | |
|
392 | """Context Manager for performing simultaneous local and remote imports. | |
|
393 | ||
|
394 | 'import x as y' will *not* work. The 'as y' part will simply be ignored. | |
|
395 | ||
|
396 | >>> with view.sync_imports(): | |
|
397 | ... from numpy import recarray | |
|
398 | importing recarray from numpy on engine(s) | |
|
399 | ||
|
400 | """ | |
|
401 | import __builtin__ | |
|
402 | local_import = __builtin__.__import__ | |
|
403 | modules = set() | |
|
404 | results = [] | |
|
405 | @util.interactive | |
|
406 | def remote_import(name, fromlist, level): | |
|
407 | """the function to be passed to apply, that actually performs the import | |
|
408 | on the engine, and loads up the user namespace. | |
|
409 | """ | |
|
410 | import sys | |
|
411 | user_ns = globals() | |
|
412 | mod = __import__(name, fromlist=fromlist, level=level) | |
|
413 | if fromlist: | |
|
414 | for key in fromlist: | |
|
415 | user_ns[key] = getattr(mod, key) | |
|
416 | else: | |
|
417 | user_ns[name] = sys.modules[name] | |
|
418 | ||
|
419 | def view_import(name, globals={}, locals={}, fromlist=[], level=-1): | |
|
420 | """the drop-in replacement for __import__, that optionally imports | |
|
421 | locally as well. | |
|
422 | """ | |
|
423 | # don't override nested imports | |
|
424 | save_import = __builtin__.__import__ | |
|
425 | __builtin__.__import__ = local_import | |
|
426 | ||
|
427 | if imp.lock_held(): | |
|
428 | # this is a side-effect import, don't do it remotely, or even | |
|
429 | # ignore the local effects | |
|
430 | return local_import(name, globals, locals, fromlist, level) | |
|
431 | ||
|
432 | imp.acquire_lock() | |
|
433 | if local: | |
|
434 | mod = local_import(name, globals, locals, fromlist, level) | |
|
435 | else: | |
|
436 | raise NotImplementedError("remote-only imports not yet implemented") | |
|
437 | imp.release_lock() | |
|
438 | ||
|
439 | key = name+':'+','.join(fromlist or []) | |
|
440 | if level == -1 and key not in modules: | |
|
441 | modules.add(key) | |
|
442 | if fromlist: | |
|
443 | print "importing %s from %s on engine(s)"%(','.join(fromlist), name) | |
|
444 | else: | |
|
445 | print "importing %s on engine(s)"%name | |
|
446 | results.append(self.apply_async(remote_import, name, fromlist, level)) | |
|
447 | # restore override | |
|
448 | __builtin__.__import__ = save_import | |
|
449 | ||
|
450 | return mod | |
|
451 | ||
|
452 | # override __import__ | |
|
453 | __builtin__.__import__ = view_import | |
|
454 | try: | |
|
455 | # enter the block | |
|
456 | yield | |
|
457 | except ImportError: | |
|
458 | if not local: | |
|
459 | # ignore import errors if not doing local imports | |
|
460 | pass | |
|
461 | finally: | |
|
462 | # always restore __import__ | |
|
463 | __builtin__.__import__ = local_import | |
|
464 | ||
|
465 | for r in results: | |
|
466 | # raise possible remote ImportErrors here | |
|
467 | r.get() | |
|
468 | ||
|
381 | 469 | |
|
382 | 470 | @sync_results |
|
383 | 471 | @save_ids |
|
384 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None): | |
|
472 | def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None): | |
|
385 | 473 | """calls f(*args, **kwargs) on remote engines, returning the result. |
|
386 | 474 | |
|
387 | 475 | This method sets all of `apply`'s flags via this View's attributes. |
@@ -395,6 +483,8 b' class DirectView(View):' | |||
|
395 | 483 | |
|
396 | 484 | kwargs : dict [default: empty] |
|
397 | 485 | |
|
486 | targets : target list [default: self.targets] | |
|
487 | where to run | |
|
398 | 488 | block : bool [default: self.block] |
|
399 | 489 | whether to block |
|
400 | 490 | track : bool [default: self.track] |
@@ -414,16 +504,19 b' class DirectView(View):' | |||
|
414 | 504 | kwargs = {} if kwargs is None else kwargs |
|
415 | 505 | block = self.block if block is None else block |
|
416 | 506 | track = self.track if track is None else track |
|
507 | targets = self.targets if targets is None else targets | |
|
508 | ||
|
509 | _idents = self.client._build_targets(targets)[0] | |
|
417 | 510 | msg_ids = [] |
|
418 | 511 | trackers = [] |
|
419 |
for ident in |
|
|
512 | for ident in _idents: | |
|
420 | 513 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, |
|
421 | 514 | ident=ident) |
|
422 | 515 | if track: |
|
423 | 516 | trackers.append(msg['tracker']) |
|
424 | 517 | msg_ids.append(msg['msg_id']) |
|
425 | 518 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
426 |
ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets= |
|
|
519 | ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) | |
|
427 | 520 | if block: |
|
428 | 521 | try: |
|
429 | 522 | return ar.get() |
@@ -474,7 +567,7 b' class DirectView(View):' | |||
|
474 | 567 | pf = ParallelFunction(self, f, block=block, **kwargs) |
|
475 | 568 | return pf.map(*sequences) |
|
476 | 569 | |
|
477 | def execute(self, code, block=None): | |
|
570 | def execute(self, code, targets=None, block=None): | |
|
478 | 571 | """Executes `code` on `targets` in blocking or nonblocking manner. |
|
479 | 572 | |
|
480 | 573 | ``execute`` is always `bound` (affects engine namespace) |
@@ -488,9 +581,9 b' class DirectView(View):' | |||
|
488 | 581 | whether or not to wait until done to return |
|
489 | 582 | default: self.block |
|
490 | 583 | """ |
|
491 | return self._really_apply(util._execute, args=(code,), block=block) | |
|
584 | return self._really_apply(util._execute, args=(code,), block=block, targets=targets) | |
|
492 | 585 | |
|
493 | def run(self, filename, block=None): | |
|
586 | def run(self, filename, targets=None, block=None): | |
|
494 | 587 | """Execute contents of `filename` on my engine(s). |
|
495 | 588 | |
|
496 | 589 | This simply reads the contents of the file and calls `execute`. |
@@ -512,7 +605,7 b' class DirectView(View):' | |||
|
512 | 605 | # add newline in case of trailing indented whitespace |
|
513 | 606 | # which will cause SyntaxError |
|
514 | 607 | code = f.read()+'\n' |
|
515 | return self.execute(code, block=block) | |
|
608 | return self.execute(code, block=block, targets=targets) | |
|
516 | 609 | |
|
517 | 610 | def update(self, ns): |
|
518 | 611 | """update remote namespace with dict `ns` |
@@ -521,7 +614,7 b' class DirectView(View):' | |||
|
521 | 614 | """ |
|
522 | 615 | return self.push(ns, block=self.block, track=self.track) |
|
523 | 616 | |
|
524 | def push(self, ns, block=None, track=None): | |
|
617 | def push(self, ns, targets=None, block=None, track=None): | |
|
525 | 618 | """update remote namespace with dict `ns` |
|
526 | 619 | |
|
527 | 620 | Parameters |
@@ -536,10 +629,11 b' class DirectView(View):' | |||
|
536 | 629 | |
|
537 | 630 | block = block if block is not None else self.block |
|
538 | 631 | track = track if track is not None else self.track |
|
632 | targets = targets if targets is not None else self.targets | |
|
539 | 633 | # applier = self.apply_sync if block else self.apply_async |
|
540 | 634 | if not isinstance(ns, dict): |
|
541 | 635 | raise TypeError("Must be a dict, not %s"%type(ns)) |
|
542 | return self._really_apply(util._push, (ns,),block=block, track=track) | |
|
636 | return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets) | |
|
543 | 637 | |
|
544 | 638 | def get(self, key_s): |
|
545 | 639 | """get object(s) by `key_s` from remote namespace |
@@ -549,13 +643,14 b' class DirectView(View):' | |||
|
549 | 643 | # block = block if block is not None else self.block |
|
550 | 644 | return self.pull(key_s, block=True) |
|
551 | 645 | |
|
552 | def pull(self, names, block=True): | |
|
646 | def pull(self, names, targets=None, block=True): | |
|
553 | 647 | """get object(s) by `name` from remote namespace |
|
554 | 648 | |
|
555 | 649 | will return one object if it is a key. |
|
556 | 650 | can also take a list of keys, in which case it will return a list of objects. |
|
557 | 651 | """ |
|
558 | 652 | block = block if block is not None else self.block |
|
653 | targets = targets if targets is not None else self.targets | |
|
559 | 654 | applier = self.apply_sync if block else self.apply_async |
|
560 | 655 | if isinstance(names, basestring): |
|
561 | 656 | pass |
@@ -565,26 +660,27 b' class DirectView(View):' | |||
|
565 | 660 | raise TypeError("keys must be str, not type %r"%type(key)) |
|
566 | 661 | else: |
|
567 | 662 | raise TypeError("names must be strs, not %r"%names) |
|
568 |
return |
|
|
663 | return self._really_apply(util._pull, (names,), block=block, targets=targets) | |
|
569 | 664 | |
|
570 | def scatter(self, key, seq, dist='b', flatten=False, block=None, track=None): | |
|
665 | def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None): | |
|
571 | 666 | """ |
|
572 | 667 | Partition a Python sequence and send the partitions to a set of engines. |
|
573 | 668 | """ |
|
574 | 669 | block = block if block is not None else self.block |
|
575 | 670 | track = track if track is not None else self.track |
|
576 | targets = self._targets | |
|
671 | targets = targets if targets is not None else self.targets | |
|
672 | ||
|
577 | 673 | mapObject = Map.dists[dist]() |
|
578 | 674 | nparts = len(targets) |
|
579 | 675 | msg_ids = [] |
|
580 | 676 | trackers = [] |
|
581 | 677 | for index, engineid in enumerate(targets): |
|
582 | push = self.client[engineid].push | |
|
583 | 678 | partition = mapObject.getPartition(seq, index, nparts) |
|
584 | 679 | if flatten and len(partition) == 1: |
|
585 |
|
|
|
680 | ns = {key: partition[0]} | |
|
586 | 681 | else: |
|
587 |
|
|
|
682 | ns = {key: partition} | |
|
683 | r = self.push(ns, block=False, track=track, targets=engineid) | |
|
588 | 684 | msg_ids.extend(r.msg_ids) |
|
589 | 685 | if track: |
|
590 | 686 | trackers.append(r._tracker) |
@@ -602,16 +698,17 b' class DirectView(View):' | |||
|
602 | 698 | |
|
603 | 699 | @sync_results |
|
604 | 700 | @save_ids |
|
605 | def gather(self, key, dist='b', block=None): | |
|
701 | def gather(self, key, dist='b', targets=None, block=None): | |
|
606 | 702 | """ |
|
607 | 703 | Gather a partitioned sequence on a set of engines as a single local seq. |
|
608 | 704 | """ |
|
609 | 705 | block = block if block is not None else self.block |
|
706 | targets = targets if targets is not None else self.targets | |
|
610 | 707 | mapObject = Map.dists[dist]() |
|
611 | 708 | msg_ids = [] |
|
612 | for index, engineid in enumerate(self._targets): | |
|
613 | ||
|
614 |
msg_ids.extend(self |
|
|
709 | ||
|
710 | for index, engineid in enumerate(targets): | |
|
711 | msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids) | |
|
615 | 712 | |
|
616 | 713 | r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') |
|
617 | 714 | |
@@ -628,15 +725,17 b' class DirectView(View):' | |||
|
628 | 725 | def __setitem__(self,key, value): |
|
629 | 726 | self.update({key:value}) |
|
630 | 727 | |
|
631 | def clear(self, block=False): | |
|
728 | def clear(self, targets=None, block=False): | |
|
632 | 729 | """Clear the remote namespaces on my engines.""" |
|
633 | 730 | block = block if block is not None else self.block |
|
634 | return self.client.clear(targets=self._targets, block=block) | |
|
731 | targets = targets if targets is not None else self.targets | |
|
732 | return self.client.clear(targets=targets, block=block) | |
|
635 | 733 | |
|
636 | def kill(self, block=True): | |
|
734 | def kill(self, targets=None, block=True): | |
|
637 | 735 | """Kill my engines.""" |
|
638 | 736 | block = block if block is not None else self.block |
|
639 | return self.client.kill(targets=self._targets, block=block) | |
|
737 | targets = targets if targets is not None else self.targets | |
|
738 | return self.client.kill(targets=targets, block=block) | |
|
640 | 739 | |
|
641 | 740 | #---------------------------------------- |
|
642 | 741 | # activate for %px,%autopx magics |
@@ -684,15 +783,16 b' class LoadBalancedView(View):' | |||
|
684 | 783 | |
|
685 | 784 | """ |
|
686 | 785 | |
|
687 | _flag_names = ['block', 'track', 'follow', 'after', 'timeout'] | |
|
786 | follow=Any() | |
|
787 | after=Any() | |
|
788 | timeout=CFloat() | |
|
688 | 789 | |
|
689 | def __init__(self, client=None, socket=None, targets=None): | |
|
690 | super(LoadBalancedView, self).__init__(client=client, socket=socket, targets=targets) | |
|
691 | self._ntargets = 1 | |
|
790 | _task_scheme = Any() | |
|
791 | _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout']) | |
|
792 | ||
|
793 | def __init__(self, client=None, socket=None, **flags): | |
|
794 | super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags) | |
|
692 | 795 | self._task_scheme=client._task_scheme |
|
693 | if targets is None: | |
|
694 | self._targets = None | |
|
695 | self._idents=[] | |
|
696 | 796 | |
|
697 | 797 | def _validate_dependency(self, dep): |
|
698 | 798 | """validate a dependency. |
@@ -786,7 +886,8 b' class LoadBalancedView(View):' | |||
|
786 | 886 | @sync_results |
|
787 | 887 | @save_ids |
|
788 | 888 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, |
|
789 |
after=None, follow=None, timeout=None |
|
|
889 | after=None, follow=None, timeout=None, | |
|
890 | targets=None): | |
|
790 | 891 | """calls f(*args, **kwargs) on a remote engine, returning the result. |
|
791 | 892 | |
|
792 | 893 | This method temporarily sets all of `apply`'s flags for a single call. |
@@ -844,9 +945,16 b' class LoadBalancedView(View):' | |||
|
844 | 945 | after = self.after if after is None else after |
|
845 | 946 | follow = self.follow if follow is None else follow |
|
846 | 947 | timeout = self.timeout if timeout is None else timeout |
|
948 | targets = self.targets if targets is None else targets | |
|
949 | ||
|
950 | if targets is None: | |
|
951 | idents = [] | |
|
952 | else: | |
|
953 | idents = self.client._build_targets(targets)[0] | |
|
954 | ||
|
847 | 955 | after = self._render_dependency(after) |
|
848 | 956 | follow = self._render_dependency(follow) |
|
849 |
subheader = dict(after=after, follow=follow, timeout=timeout, targets= |
|
|
957 | subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents) | |
|
850 | 958 | |
|
851 | 959 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, |
|
852 | 960 | subheader=subheader) |
@@ -916,5 +1024,5 b' class LoadBalancedView(View):' | |||
|
916 | 1024 | |
|
917 | 1025 | pf = ParallelFunction(self, f, block=block, chunksize=chunksize) |
|
918 | 1026 | return pf.map(*sequences) |
|
919 | ||
|
1027 | ||
|
920 | 1028 | __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file |
General Comments 0
You need to be logged in to leave comments.
Login now