Show More
@@ -10,14 +10,15 b'' | |||||
10 | # Imports |
|
10 | # Imports | |
11 | #----------------------------------------------------------------------------- |
|
11 | #----------------------------------------------------------------------------- | |
12 |
|
12 | |||
13 | # from .asyncresult import * |
|
|||
14 | # from .client import Client |
|
|||
15 | # from .dependency import * |
|
|||
16 | # from .remotefunction import * |
|
|||
17 | # from .view import * |
|
|||
18 |
|
||||
19 | import zmq |
|
13 | import zmq | |
20 |
|
14 | |||
21 | if zmq.__version__ < '2.1.3': |
|
15 | if zmq.__version__ < '2.1.3': | |
22 | raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__) |
|
16 | raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3, you appear to have %s"%zmq.__version__) | |
23 |
|
17 | |||
|
18 | from .asyncresult import * | |||
|
19 | from .client import Client | |||
|
20 | from .dependency import * | |||
|
21 | from .remotefunction import * | |||
|
22 | from .view import * | |||
|
23 | ||||
|
24 |
@@ -245,8 +245,6 b' class Client(HasTraits):' | |||||
245 | _mux_socket=Instance('zmq.Socket') |
|
245 | _mux_socket=Instance('zmq.Socket') | |
246 | _task_socket=Instance('zmq.Socket') |
|
246 | _task_socket=Instance('zmq.Socket') | |
247 | _task_scheme=Str() |
|
247 | _task_scheme=Str() | |
248 | _balanced_views=Dict() |
|
|||
249 | _direct_views=Dict() |
|
|||
250 | _closed = False |
|
248 | _closed = False | |
251 | _ignored_control_replies=Int(0) |
|
249 | _ignored_control_replies=Int(0) | |
252 | _ignored_hub_replies=Int(0) |
|
250 | _ignored_hub_replies=Int(0) | |
@@ -389,7 +387,20 b' class Client(HasTraits):' | |||||
389 | else: |
|
387 | else: | |
390 | raise TypeError("%r not valid str target, must be 'all'"%(targets)) |
|
388 | raise TypeError("%r not valid str target, must be 'all'"%(targets)) | |
391 | elif isinstance(targets, int): |
|
389 | elif isinstance(targets, int): | |
|
390 | if targets < 0: | |||
|
391 | targets = self.ids[targets] | |||
|
392 | if targets not in self.ids: | |||
|
393 | raise IndexError("No such engine: %i"%targets) | |||
392 | targets = [targets] |
|
394 | targets = [targets] | |
|
395 | ||||
|
396 | if isinstance(targets, slice): | |||
|
397 | indices = range(len(self._ids))[targets] | |||
|
398 | ids = self.ids | |||
|
399 | targets = [ ids[i] for i in indices ] | |||
|
400 | ||||
|
401 | if not isinstance(targets, (tuple, list, xrange)): | |||
|
402 | raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) | |||
|
403 | ||||
393 | return [self._engines[t] for t in targets], list(targets) |
|
404 | return [self._engines[t] for t in targets], list(targets) | |
394 |
|
405 | |||
395 | def _connect(self, sshserver, ssh_kwargs, timeout): |
|
406 | def _connect(self, sshserver, ssh_kwargs, timeout): | |
@@ -688,7 +699,7 b' class Client(HasTraits):' | |||||
688 | if not isinstance(key, (int, slice, tuple, list, xrange)): |
|
699 | if not isinstance(key, (int, slice, tuple, list, xrange)): | |
689 | raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key))) |
|
700 | raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key))) | |
690 | else: |
|
701 | else: | |
691 |
return self. |
|
702 | return self.direct_view(key) | |
692 |
|
703 | |||
693 | #-------------------------------------------------------------------------- |
|
704 | #-------------------------------------------------------------------------- | |
694 | # Begin public methods |
|
705 | # Begin public methods | |
@@ -962,31 +973,6 b' class Client(HasTraits):' | |||||
962 | # construct a View object |
|
973 | # construct a View object | |
963 | #-------------------------------------------------------------------------- |
|
974 | #-------------------------------------------------------------------------- | |
964 |
|
975 | |||
965 | def _cache_view(self, targets, balanced): |
|
|||
966 | """save views, so subsequent requests don't create new objects.""" |
|
|||
967 | if balanced: |
|
|||
968 | # validate whether we can run |
|
|||
969 | if not self._task_socket: |
|
|||
970 | msg = "Task farming is disabled" |
|
|||
971 | if self._task_scheme == 'pure': |
|
|||
972 | msg += " because the pure ZMQ scheduler cannot handle" |
|
|||
973 | msg += " disappearing engines." |
|
|||
974 | raise RuntimeError(msg) |
|
|||
975 | socket = self._task_socket |
|
|||
976 | view_class = LoadBalancedView |
|
|||
977 | view_cache = self._balanced_views |
|
|||
978 | else: |
|
|||
979 | socket = self._mux_socket |
|
|||
980 | view_class = DirectView |
|
|||
981 | view_cache = self._direct_views |
|
|||
982 |
|
||||
983 | # use str, since often targets will be a list |
|
|||
984 | key = str(targets) |
|
|||
985 | if key not in view_cache: |
|
|||
986 | view_cache[key] = view_class(client=self, socket=socket, targets=targets) |
|
|||
987 |
|
||||
988 | return view_cache[key] |
|
|||
989 |
|
||||
990 | def load_balanced_view(self, targets=None): |
|
976 | def load_balanced_view(self, targets=None): | |
991 | """construct a DirectView object. |
|
977 | """construct a DirectView object. | |
992 |
|
978 | |||
@@ -999,7 +985,9 b' class Client(HasTraits):' | |||||
999 | targets: list,slice,int,etc. [default: use all engines] |
|
985 | targets: list,slice,int,etc. [default: use all engines] | |
1000 | The subset of engines across which to load-balance |
|
986 | The subset of engines across which to load-balance | |
1001 | """ |
|
987 | """ | |
1002 | return self._get_view(targets, balanced=True) |
|
988 | if targets is None: | |
|
989 | targets = self._build_targets(targets)[1] | |||
|
990 | return LoadBalancedView(client=self, socket=self._task_socket, targets=targets) | |||
1003 |
|
991 | |||
1004 | def direct_view(self, targets='all'): |
|
992 | def direct_view(self, targets='all'): | |
1005 | """construct a DirectView object. |
|
993 | """construct a DirectView object. | |
@@ -1013,49 +1001,11 b' class Client(HasTraits):' | |||||
1013 | targets: list,slice,int,etc. [default: use all engines] |
|
1001 | targets: list,slice,int,etc. [default: use all engines] | |
1014 | The engines to use for the View |
|
1002 | The engines to use for the View | |
1015 | """ |
|
1003 | """ | |
1016 | return self._get_view(targets, balanced=False) |
|
1004 | single = isinstance(targets, int) | |
1017 |
|
1005 | targets = self._build_targets(targets)[1] | ||
1018 | def _get_view(self, targets, balanced): |
|
1006 | if single: | |
1019 | """Method for constructing View objects. |
|
1007 | targets = targets[0] | |
1020 |
|
1008 | return DirectView(client=self, socket=self._mux_socket, targets=targets) | ||
1021 | If no arguments are specified, create a LoadBalancedView |
|
|||
1022 | using all engines. If only `targets` specified, it will |
|
|||
1023 | be a DirectView. This method is the underlying implementation |
|
|||
1024 | of ``client.__getitem__``. |
|
|||
1025 |
|
||||
1026 | Parameters |
|
|||
1027 | ---------- |
|
|||
1028 |
|
||||
1029 | targets: list,slice,int,etc. [default: use all engines] |
|
|||
1030 | The engines to use for the View |
|
|||
1031 | balanced : bool [default: False if targets specified, True else] |
|
|||
1032 | whether to build a LoadBalancedView or a DirectView |
|
|||
1033 |
|
||||
1034 | """ |
|
|||
1035 |
|
||||
1036 | if targets in (None,'all'): |
|
|||
1037 | if balanced: |
|
|||
1038 | return self._cache_view(None,True) |
|
|||
1039 | else: |
|
|||
1040 | targets = slice(None) |
|
|||
1041 |
|
||||
1042 | if isinstance(targets, int): |
|
|||
1043 | if targets < 0: |
|
|||
1044 | targets = self.ids[targets] |
|
|||
1045 | if targets not in self.ids: |
|
|||
1046 | raise IndexError("No such engine: %i"%targets) |
|
|||
1047 | return self._cache_view(targets, balanced) |
|
|||
1048 |
|
||||
1049 | if isinstance(targets, slice): |
|
|||
1050 | indices = range(len(self.ids))[targets] |
|
|||
1051 | ids = sorted(self._ids) |
|
|||
1052 | targets = [ ids[i] for i in indices ] |
|
|||
1053 |
|
||||
1054 | if isinstance(targets, (tuple, list, xrange)): |
|
|||
1055 | _,targets = self._build_targets(list(targets)) |
|
|||
1056 | return self._cache_view(targets, balanced) |
|
|||
1057 | else: |
|
|||
1058 | raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets))) |
|
|||
1059 |
|
1009 | |||
1060 | #-------------------------------------------------------------------------- |
|
1010 | #-------------------------------------------------------------------------- | |
1061 | # Data movement (TO BE REMOVED) |
|
1011 | # Data movement (TO BE REMOVED) |
@@ -6,6 +6,8 b'' | |||||
6 | # the file COPYING, distributed as part of this software. |
|
6 | # the file COPYING, distributed as part of this software. | |
7 | #----------------------------------------------------------------------------- |
|
7 | #----------------------------------------------------------------------------- | |
8 |
|
8 | |||
|
9 | from types import ModuleType | |||
|
10 | ||||
9 | from .asyncresult import AsyncResult |
|
11 | from .asyncresult import AsyncResult | |
10 | from .error import UnmetDependency |
|
12 | from .error import UnmetDependency | |
11 | from .util import interactive |
|
13 | from .util import interactive | |
@@ -76,7 +78,7 b' def _require(*names):' | |||||
76 | raise UnmetDependency(name) |
|
78 | raise UnmetDependency(name) | |
77 | return True |
|
79 | return True | |
78 |
|
80 | |||
79 |
def require(* |
|
81 | def require(*mods): | |
80 | """Simple decorator for requiring names to be importable. |
|
82 | """Simple decorator for requiring names to be importable. | |
81 |
|
83 | |||
82 | Examples |
|
84 | Examples | |
@@ -87,6 +89,16 b' def require(*names):' | |||||
87 | ...: import numpy |
|
89 | ...: import numpy | |
88 | ...: return numpy.linalg.norm(a,2) |
|
90 | ...: return numpy.linalg.norm(a,2) | |
89 | """ |
|
91 | """ | |
|
92 | names = [] | |||
|
93 | for mod in mods: | |||
|
94 | if isinstance(mod, ModuleType): | |||
|
95 | mod = mod.__name__ | |||
|
96 | ||||
|
97 | if isinstance(mod, basestring): | |||
|
98 | names.append(mod) | |||
|
99 | else: | |||
|
100 | raise TypeError("names must be modules or module names, not %s"%type(mod)) | |||
|
101 | ||||
90 | return depend(_require, *names) |
|
102 | return depend(_require, *names) | |
91 |
|
103 | |||
92 | class Dependency(set): |
|
104 | class Dependency(set): |
@@ -48,9 +48,9 b' from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError' | |||||
48 |
|
48 | |||
49 | from .factory import LoggingFactory |
|
49 | from .factory import LoggingFactory | |
50 |
|
50 | |||
51 | # load winhpcjob from IPython.kernel |
|
51 | # load winhpcjob only on Windows | |
52 | try: |
|
52 | try: | |
53 |
from |
|
53 | from .winhpcjob import ( | |
54 | IPControllerTask, IPEngineTask, |
|
54 | IPControllerTask, IPEngineTask, | |
55 | IPControllerJob, IPEngineSetJob |
|
55 | IPControllerJob, IPEngineSetJob | |
56 | ) |
|
56 | ) |
@@ -61,15 +61,6 b' class TestClient(ClusterTestCase):' | |||||
61 | self.assertEquals(v.targets, targets[-1]) |
|
61 | self.assertEquals(v.targets, targets[-1]) | |
62 | self.assertRaises(TypeError, lambda : self.client[None]) |
|
62 | self.assertRaises(TypeError, lambda : self.client[None]) | |
63 |
|
63 | |||
64 | def test_view_cache(self): |
|
|||
65 | """test that multiple view requests return the same object""" |
|
|||
66 | v = self.client[:2] |
|
|||
67 | v2 =self.client[:2] |
|
|||
68 | self.assertTrue(v is v2) |
|
|||
69 | v = self.client.load_balanced_view() |
|
|||
70 | v2 = self.client.load_balanced_view(targets=None) |
|
|||
71 | self.assertTrue(v is v2) |
|
|||
72 |
|
||||
73 | def test_targets(self): |
|
64 | def test_targets(self): | |
74 | """test various valid targets arguments""" |
|
65 | """test various valid targets arguments""" | |
75 | build = self.client._build_targets |
|
66 | build = self.client._build_targets |
@@ -285,3 +285,17 b' class TestView(ClusterTestCase):' | |||||
285 | self.assertFalse(view.block) |
|
285 | self.assertFalse(view.block) | |
286 | self.assertTrue(view.block) |
|
286 | self.assertTrue(view.block) | |
287 |
|
287 | |||
|
288 | def test_importer(self): | |||
|
289 | view = self.client[-1] | |||
|
290 | view.clear(block=True) | |||
|
291 | with view.importer: | |||
|
292 | import re | |||
|
293 | ||||
|
294 | @interactive | |||
|
295 | def findall(pat, s): | |||
|
296 | # this globals() step isn't necessary in real code | |||
|
297 | # only to prevent a closure in the test | |||
|
298 | return globals()['re'].findall(pat, s) | |||
|
299 | ||||
|
300 | self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split()) | |||
|
301 |
@@ -10,13 +10,16 b'' | |||||
10 | # Imports |
|
10 | # Imports | |
11 | #----------------------------------------------------------------------------- |
|
11 | #----------------------------------------------------------------------------- | |
12 |
|
12 | |||
|
13 | import imp | |||
|
14 | import sys | |||
13 | import warnings |
|
15 | import warnings | |
14 | from contextlib import contextmanager |
|
16 | from contextlib import contextmanager | |
|
17 | from types import ModuleType | |||
15 |
|
18 | |||
16 | import zmq |
|
19 | import zmq | |
17 |
|
20 | |||
18 | from IPython.testing import decorators as testdec |
|
21 | from IPython.testing import decorators as testdec | |
19 | from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance |
|
22 | from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat | |
20 |
|
23 | |||
21 | from IPython.external.decorator import decorator |
|
24 | from IPython.external.decorator import decorator | |
22 |
|
25 | |||
@@ -94,48 +97,36 b' class View(HasTraits):' | |||||
94 | abort, shutdown |
|
97 | abort, shutdown | |
95 |
|
98 | |||
96 | """ |
|
99 | """ | |
|
100 | # flags | |||
97 | block=Bool(False) |
|
101 | block=Bool(False) | |
98 | track=Bool(True) |
|
102 | track=Bool(True) | |
|
103 | targets = Any() | |||
|
104 | ||||
99 | history=List() |
|
105 | history=List() | |
100 | outstanding = Set() |
|
106 | outstanding = Set() | |
101 | results = Dict() |
|
107 | results = Dict() | |
102 | client = Instance('IPython.zmq.parallel.client.Client') |
|
108 | client = Instance('IPython.zmq.parallel.client.Client') | |
103 |
|
109 | |||
104 | _socket = Instance('zmq.Socket') |
|
110 | _socket = Instance('zmq.Socket') | |
105 | _ntargets = Int(1) |
|
111 | _flag_names = List(['targets', 'block', 'track']) | |
106 | _flag_names = List(['block', 'track']) |
|
|||
107 | _targets = Any() |
|
112 | _targets = Any() | |
108 | _idents = Any() |
|
113 | _idents = Any() | |
109 |
|
114 | |||
110 |
def __init__(self, client=None, socket=None, |
|
115 | def __init__(self, client=None, socket=None, **flags): | |
111 | super(View, self).__init__(client=client, _socket=socket) |
|
116 | super(View, self).__init__(client=client, _socket=socket) | |
112 | self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets) |
|
|||
113 | self.block = client.block |
|
117 | self.block = client.block | |
114 |
|
118 | |||
115 | self._idents, self._targets = self.client._build_targets(targets) |
|
119 | self.set_flags(**flags) | |
116 | if targets is None or isinstance(targets, int): |
|
|||
117 | self._targets = targets |
|
|||
118 | for name in self._flag_names: |
|
|||
119 | # set flags, if they haven't been set yet |
|
|||
120 | setattr(self, name, getattr(self, name, None)) |
|
|||
121 |
|
120 | |||
122 | assert not self.__class__ is View, "Don't use base View objects, use subclasses" |
|
121 | assert not self.__class__ is View, "Don't use base View objects, use subclasses" | |
123 |
|
122 | |||
124 |
|
123 | |||
125 | def __repr__(self): |
|
124 | def __repr__(self): | |
126 |
strtargets = str(self. |
|
125 | strtargets = str(self.targets) | |
127 | if len(strtargets) > 16: |
|
126 | if len(strtargets) > 16: | |
128 | strtargets = strtargets[:12]+'...]' |
|
127 | strtargets = strtargets[:12]+'...]' | |
129 | return "<%s %s>"%(self.__class__.__name__, strtargets) |
|
128 | return "<%s %s>"%(self.__class__.__name__, strtargets) | |
130 |
|
129 | |||
131 | @property |
|
|||
132 | def targets(self): |
|
|||
133 | return self._targets |
|
|||
134 |
|
||||
135 | @targets.setter |
|
|||
136 | def targets(self, value): |
|
|||
137 | raise AttributeError("Cannot set View `targets` after construction!") |
|
|||
138 |
|
||||
139 | def set_flags(self, **kwargs): |
|
130 | def set_flags(self, **kwargs): | |
140 | """set my attribute flags by keyword. |
|
131 | """set my attribute flags by keyword. | |
141 |
|
132 | |||
@@ -182,9 +173,11 b' class View(HasTraits):' | |||||
182 | saved_flags[f] = getattr(self, f) |
|
173 | saved_flags[f] = getattr(self, f) | |
183 | self.set_flags(**kwargs) |
|
174 | self.set_flags(**kwargs) | |
184 | # yield to the with-statement block |
|
175 | # yield to the with-statement block | |
185 |
|
|
176 | try: | |
186 | # postflight: restore saved flags |
|
177 | yield | |
187 | self.set_flags(**saved_flags) |
|
178 | finally: | |
|
179 | # postflight: restore saved flags | |||
|
180 | self.set_flags(**saved_flags) | |||
188 |
|
181 | |||
189 |
|
182 | |||
190 | #---------------------------------------------------------------- |
|
183 | #---------------------------------------------------------------- | |
@@ -258,7 +251,7 b' class View(HasTraits):' | |||||
258 | jobs = self.history |
|
251 | jobs = self.history | |
259 | return self.client.wait(jobs, timeout) |
|
252 | return self.client.wait(jobs, timeout) | |
260 |
|
253 | |||
261 | def abort(self, jobs=None, block=None): |
|
254 | def abort(self, jobs=None, targets=None, block=None): | |
262 | """Abort jobs on my engines. |
|
255 | """Abort jobs on my engines. | |
263 |
|
256 | |||
264 | Parameters |
|
257 | Parameters | |
@@ -269,16 +262,18 b' class View(HasTraits):' | |||||
269 | else: abort specific msg_id(s). |
|
262 | else: abort specific msg_id(s). | |
270 | """ |
|
263 | """ | |
271 | block = block if block is not None else self.block |
|
264 | block = block if block is not None else self.block | |
272 | return self.client.abort(jobs=jobs, targets=self._targets, block=block) |
|
265 | targets = targets if targets is not None else self.targets | |
|
266 | return self.client.abort(jobs=jobs, targets=targets, block=block) | |||
273 |
|
267 | |||
274 | def queue_status(self, verbose=False): |
|
268 | def queue_status(self, targets=None, verbose=False): | |
275 | """Fetch the Queue status of my engines""" |
|
269 | """Fetch the Queue status of my engines""" | |
276 | return self.client.queue_status(targets=self._targets, verbose=verbose) |
|
270 | targets = targets if targets is not None else self.targets | |
|
271 | return self.client.queue_status(targets=targets, verbose=verbose) | |||
277 |
|
272 | |||
278 | def purge_results(self, jobs=[], targets=[]): |
|
273 | def purge_results(self, jobs=[], targets=[]): | |
279 | """Instruct the controller to forget specific results.""" |
|
274 | """Instruct the controller to forget specific results.""" | |
280 | if targets is None or targets == 'all': |
|
275 | if targets is None or targets == 'all': | |
281 |
targets = self. |
|
276 | targets = self.targets | |
282 | return self.client.purge_results(jobs=jobs, targets=targets) |
|
277 | return self.client.purge_results(jobs=jobs, targets=targets) | |
283 |
|
278 | |||
284 | @spin_after |
|
279 | @spin_after | |
@@ -377,11 +372,104 b' class DirectView(View):' | |||||
377 |
|
372 | |||
378 | def __init__(self, client=None, socket=None, targets=None): |
|
373 | def __init__(self, client=None, socket=None, targets=None): | |
379 | super(DirectView, self).__init__(client=client, socket=socket, targets=targets) |
|
374 | super(DirectView, self).__init__(client=client, socket=socket, targets=targets) | |
|
375 | ||||
|
376 | @property | |||
|
377 | def importer(self): | |||
|
378 | """sync_imports(local=True) as a property. | |||
380 |
|
|
379 | ||
|
380 | See sync_imports for details. | |||
|
381 | ||||
|
382 | In [10]: with v.importer: | |||
|
383 | ....: import numpy | |||
|
384 | ....: | |||
|
385 | importing numpy on engine(s) | |||
|
386 | ||||
|
387 | """ | |||
|
388 | return self.sync_imports(True) | |||
|
389 | ||||
|
390 | @contextmanager | |||
|
391 | def sync_imports(self, local=True): | |||
|
392 | """Context Manager for performing simultaneous local and remote imports. | |||
|
393 | ||||
|
394 | 'import x as y' will *not* work. The 'as y' part will simply be ignored. | |||
|
395 | ||||
|
396 | >>> with view.sync_imports(): | |||
|
397 | ... from numpy import recarray | |||
|
398 | importing recarray from numpy on engine(s) | |||
|
399 | ||||
|
400 | """ | |||
|
401 | import __builtin__ | |||
|
402 | local_import = __builtin__.__import__ | |||
|
403 | modules = set() | |||
|
404 | results = [] | |||
|
405 | @util.interactive | |||
|
406 | def remote_import(name, fromlist, level): | |||
|
407 | """the function to be passed to apply, that actually performs the import | |||
|
408 | on the engine, and loads up the user namespace. | |||
|
409 | """ | |||
|
410 | import sys | |||
|
411 | user_ns = globals() | |||
|
412 | mod = __import__(name, fromlist=fromlist, level=level) | |||
|
413 | if fromlist: | |||
|
414 | for key in fromlist: | |||
|
415 | user_ns[key] = getattr(mod, key) | |||
|
416 | else: | |||
|
417 | user_ns[name] = sys.modules[name] | |||
|
418 | ||||
|
419 | def view_import(name, globals={}, locals={}, fromlist=[], level=-1): | |||
|
420 | """the drop-in replacement for __import__, that optionally imports | |||
|
421 | locally as well. | |||
|
422 | """ | |||
|
423 | # don't override nested imports | |||
|
424 | save_import = __builtin__.__import__ | |||
|
425 | __builtin__.__import__ = local_import | |||
|
426 | ||||
|
427 | if imp.lock_held(): | |||
|
428 | # this is a side-effect import, don't do it remotely, or even | |||
|
429 | # ignore the local effects | |||
|
430 | return local_import(name, globals, locals, fromlist, level) | |||
|
431 | ||||
|
432 | imp.acquire_lock() | |||
|
433 | if local: | |||
|
434 | mod = local_import(name, globals, locals, fromlist, level) | |||
|
435 | else: | |||
|
436 | raise NotImplementedError("remote-only imports not yet implemented") | |||
|
437 | imp.release_lock() | |||
|
438 | ||||
|
439 | key = name+':'+','.join(fromlist or []) | |||
|
440 | if level == -1 and key not in modules: | |||
|
441 | modules.add(key) | |||
|
442 | if fromlist: | |||
|
443 | print "importing %s from %s on engine(s)"%(','.join(fromlist), name) | |||
|
444 | else: | |||
|
445 | print "importing %s on engine(s)"%name | |||
|
446 | results.append(self.apply_async(remote_import, name, fromlist, level)) | |||
|
447 | # restore override | |||
|
448 | __builtin__.__import__ = save_import | |||
|
449 | ||||
|
450 | return mod | |||
|
451 | ||||
|
452 | # override __import__ | |||
|
453 | __builtin__.__import__ = view_import | |||
|
454 | try: | |||
|
455 | # enter the block | |||
|
456 | yield | |||
|
457 | except ImportError: | |||
|
458 | if not local: | |||
|
459 | # ignore import errors if not doing local imports | |||
|
460 | pass | |||
|
461 | finally: | |||
|
462 | # always restore __import__ | |||
|
463 | __builtin__.__import__ = local_import | |||
|
464 | ||||
|
465 | for r in results: | |||
|
466 | # raise possible remote ImportErrors here | |||
|
467 | r.get() | |||
|
468 | ||||
381 |
|
469 | |||
382 | @sync_results |
|
470 | @sync_results | |
383 | @save_ids |
|
471 | @save_ids | |
384 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None): |
|
472 | def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None): | |
385 | """calls f(*args, **kwargs) on remote engines, returning the result. |
|
473 | """calls f(*args, **kwargs) on remote engines, returning the result. | |
386 |
|
474 | |||
387 | This method sets all of `apply`'s flags via this View's attributes. |
|
475 | This method sets all of `apply`'s flags via this View's attributes. | |
@@ -395,6 +483,8 b' class DirectView(View):' | |||||
395 |
|
483 | |||
396 | kwargs : dict [default: empty] |
|
484 | kwargs : dict [default: empty] | |
397 |
|
485 | |||
|
486 | targets : target list [default: self.targets] | |||
|
487 | where to run | |||
398 | block : bool [default: self.block] |
|
488 | block : bool [default: self.block] | |
399 | whether to block |
|
489 | whether to block | |
400 | track : bool [default: self.track] |
|
490 | track : bool [default: self.track] | |
@@ -414,16 +504,19 b' class DirectView(View):' | |||||
414 | kwargs = {} if kwargs is None else kwargs |
|
504 | kwargs = {} if kwargs is None else kwargs | |
415 | block = self.block if block is None else block |
|
505 | block = self.block if block is None else block | |
416 | track = self.track if track is None else track |
|
506 | track = self.track if track is None else track | |
|
507 | targets = self.targets if targets is None else targets | |||
|
508 | ||||
|
509 | _idents = self.client._build_targets(targets)[0] | |||
417 | msg_ids = [] |
|
510 | msg_ids = [] | |
418 | trackers = [] |
|
511 | trackers = [] | |
419 |
for ident in |
|
512 | for ident in _idents: | |
420 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, |
|
513 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, | |
421 | ident=ident) |
|
514 | ident=ident) | |
422 | if track: |
|
515 | if track: | |
423 | trackers.append(msg['tracker']) |
|
516 | trackers.append(msg['tracker']) | |
424 | msg_ids.append(msg['msg_id']) |
|
517 | msg_ids.append(msg['msg_id']) | |
425 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
518 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |
426 |
ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets= |
|
519 | ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) | |
427 | if block: |
|
520 | if block: | |
428 | try: |
|
521 | try: | |
429 | return ar.get() |
|
522 | return ar.get() | |
@@ -474,7 +567,7 b' class DirectView(View):' | |||||
474 | pf = ParallelFunction(self, f, block=block, **kwargs) |
|
567 | pf = ParallelFunction(self, f, block=block, **kwargs) | |
475 | return pf.map(*sequences) |
|
568 | return pf.map(*sequences) | |
476 |
|
569 | |||
477 | def execute(self, code, block=None): |
|
570 | def execute(self, code, targets=None, block=None): | |
478 | """Executes `code` on `targets` in blocking or nonblocking manner. |
|
571 | """Executes `code` on `targets` in blocking or nonblocking manner. | |
479 |
|
572 | |||
480 | ``execute`` is always `bound` (affects engine namespace) |
|
573 | ``execute`` is always `bound` (affects engine namespace) | |
@@ -488,9 +581,9 b' class DirectView(View):' | |||||
488 | whether or not to wait until done to return |
|
581 | whether or not to wait until done to return | |
489 | default: self.block |
|
582 | default: self.block | |
490 | """ |
|
583 | """ | |
491 | return self._really_apply(util._execute, args=(code,), block=block) |
|
584 | return self._really_apply(util._execute, args=(code,), block=block, targets=targets) | |
492 |
|
585 | |||
493 | def run(self, filename, block=None): |
|
586 | def run(self, filename, targets=None, block=None): | |
494 | """Execute contents of `filename` on my engine(s). |
|
587 | """Execute contents of `filename` on my engine(s). | |
495 |
|
588 | |||
496 | This simply reads the contents of the file and calls `execute`. |
|
589 | This simply reads the contents of the file and calls `execute`. | |
@@ -512,7 +605,7 b' class DirectView(View):' | |||||
512 | # add newline in case of trailing indented whitespace |
|
605 | # add newline in case of trailing indented whitespace | |
513 | # which will cause SyntaxError |
|
606 | # which will cause SyntaxError | |
514 | code = f.read()+'\n' |
|
607 | code = f.read()+'\n' | |
515 | return self.execute(code, block=block) |
|
608 | return self.execute(code, block=block, targets=targets) | |
516 |
|
609 | |||
517 | def update(self, ns): |
|
610 | def update(self, ns): | |
518 | """update remote namespace with dict `ns` |
|
611 | """update remote namespace with dict `ns` | |
@@ -521,7 +614,7 b' class DirectView(View):' | |||||
521 | """ |
|
614 | """ | |
522 | return self.push(ns, block=self.block, track=self.track) |
|
615 | return self.push(ns, block=self.block, track=self.track) | |
523 |
|
616 | |||
524 | def push(self, ns, block=None, track=None): |
|
617 | def push(self, ns, targets=None, block=None, track=None): | |
525 | """update remote namespace with dict `ns` |
|
618 | """update remote namespace with dict `ns` | |
526 |
|
619 | |||
527 | Parameters |
|
620 | Parameters | |
@@ -536,10 +629,11 b' class DirectView(View):' | |||||
536 |
|
629 | |||
537 | block = block if block is not None else self.block |
|
630 | block = block if block is not None else self.block | |
538 | track = track if track is not None else self.track |
|
631 | track = track if track is not None else self.track | |
|
632 | targets = targets if targets is not None else self.targets | |||
539 | # applier = self.apply_sync if block else self.apply_async |
|
633 | # applier = self.apply_sync if block else self.apply_async | |
540 | if not isinstance(ns, dict): |
|
634 | if not isinstance(ns, dict): | |
541 | raise TypeError("Must be a dict, not %s"%type(ns)) |
|
635 | raise TypeError("Must be a dict, not %s"%type(ns)) | |
542 | return self._really_apply(util._push, (ns,),block=block, track=track) |
|
636 | return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets) | |
543 |
|
637 | |||
544 | def get(self, key_s): |
|
638 | def get(self, key_s): | |
545 | """get object(s) by `key_s` from remote namespace |
|
639 | """get object(s) by `key_s` from remote namespace | |
@@ -549,13 +643,14 b' class DirectView(View):' | |||||
549 | # block = block if block is not None else self.block |
|
643 | # block = block if block is not None else self.block | |
550 | return self.pull(key_s, block=True) |
|
644 | return self.pull(key_s, block=True) | |
551 |
|
645 | |||
552 | def pull(self, names, block=True): |
|
646 | def pull(self, names, targets=None, block=True): | |
553 | """get object(s) by `name` from remote namespace |
|
647 | """get object(s) by `name` from remote namespace | |
554 |
|
648 | |||
555 | will return one object if it is a key. |
|
649 | will return one object if it is a key. | |
556 | can also take a list of keys, in which case it will return a list of objects. |
|
650 | can also take a list of keys, in which case it will return a list of objects. | |
557 | """ |
|
651 | """ | |
558 | block = block if block is not None else self.block |
|
652 | block = block if block is not None else self.block | |
|
653 | targets = targets if targets is not None else self.targets | |||
559 | applier = self.apply_sync if block else self.apply_async |
|
654 | applier = self.apply_sync if block else self.apply_async | |
560 | if isinstance(names, basestring): |
|
655 | if isinstance(names, basestring): | |
561 | pass |
|
656 | pass | |
@@ -565,26 +660,27 b' class DirectView(View):' | |||||
565 | raise TypeError("keys must be str, not type %r"%type(key)) |
|
660 | raise TypeError("keys must be str, not type %r"%type(key)) | |
566 | else: |
|
661 | else: | |
567 | raise TypeError("names must be strs, not %r"%names) |
|
662 | raise TypeError("names must be strs, not %r"%names) | |
568 |
return |
|
663 | return self._really_apply(util._pull, (names,), block=block, targets=targets) | |
569 |
|
664 | |||
570 | def scatter(self, key, seq, dist='b', flatten=False, block=None, track=None): |
|
665 | def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None): | |
571 | """ |
|
666 | """ | |
572 | Partition a Python sequence and send the partitions to a set of engines. |
|
667 | Partition a Python sequence and send the partitions to a set of engines. | |
573 | """ |
|
668 | """ | |
574 | block = block if block is not None else self.block |
|
669 | block = block if block is not None else self.block | |
575 | track = track if track is not None else self.track |
|
670 | track = track if track is not None else self.track | |
576 | targets = self._targets |
|
671 | targets = targets if targets is not None else self.targets | |
|
672 | ||||
577 | mapObject = Map.dists[dist]() |
|
673 | mapObject = Map.dists[dist]() | |
578 | nparts = len(targets) |
|
674 | nparts = len(targets) | |
579 | msg_ids = [] |
|
675 | msg_ids = [] | |
580 | trackers = [] |
|
676 | trackers = [] | |
581 | for index, engineid in enumerate(targets): |
|
677 | for index, engineid in enumerate(targets): | |
582 | push = self.client[engineid].push |
|
|||
583 | partition = mapObject.getPartition(seq, index, nparts) |
|
678 | partition = mapObject.getPartition(seq, index, nparts) | |
584 | if flatten and len(partition) == 1: |
|
679 | if flatten and len(partition) == 1: | |
585 |
|
|
680 | ns = {key: partition[0]} | |
586 | else: |
|
681 | else: | |
587 |
|
|
682 | ns = {key: partition} | |
|
683 | r = self.push(ns, block=False, track=track, targets=engineid) | |||
588 | msg_ids.extend(r.msg_ids) |
|
684 | msg_ids.extend(r.msg_ids) | |
589 | if track: |
|
685 | if track: | |
590 | trackers.append(r._tracker) |
|
686 | trackers.append(r._tracker) | |
@@ -602,16 +698,17 b' class DirectView(View):' | |||||
602 |
|
698 | |||
603 | @sync_results |
|
699 | @sync_results | |
604 | @save_ids |
|
700 | @save_ids | |
605 | def gather(self, key, dist='b', block=None): |
|
701 | def gather(self, key, dist='b', targets=None, block=None): | |
606 | """ |
|
702 | """ | |
607 | Gather a partitioned sequence on a set of engines as a single local seq. |
|
703 | Gather a partitioned sequence on a set of engines as a single local seq. | |
608 | """ |
|
704 | """ | |
609 | block = block if block is not None else self.block |
|
705 | block = block if block is not None else self.block | |
|
706 | targets = targets if targets is not None else self.targets | |||
610 | mapObject = Map.dists[dist]() |
|
707 | mapObject = Map.dists[dist]() | |
611 | msg_ids = [] |
|
708 | msg_ids = [] | |
612 | for index, engineid in enumerate(self._targets): |
|
709 | ||
613 |
|
710 | for index, engineid in enumerate(targets): | ||
614 |
msg_ids.extend(self |
|
711 | msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids) | |
615 |
|
712 | |||
616 | r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') |
|
713 | r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') | |
617 |
|
714 | |||
@@ -628,15 +725,17 b' class DirectView(View):' | |||||
628 | def __setitem__(self,key, value): |
|
725 | def __setitem__(self,key, value): | |
629 | self.update({key:value}) |
|
726 | self.update({key:value}) | |
630 |
|
727 | |||
631 | def clear(self, block=False): |
|
728 | def clear(self, targets=None, block=False): | |
632 | """Clear the remote namespaces on my engines.""" |
|
729 | """Clear the remote namespaces on my engines.""" | |
633 | block = block if block is not None else self.block |
|
730 | block = block if block is not None else self.block | |
634 | return self.client.clear(targets=self._targets, block=block) |
|
731 | targets = targets if targets is not None else self.targets | |
|
732 | return self.client.clear(targets=targets, block=block) | |||
635 |
|
733 | |||
636 | def kill(self, block=True): |
|
734 | def kill(self, targets=None, block=True): | |
637 | """Kill my engines.""" |
|
735 | """Kill my engines.""" | |
638 | block = block if block is not None else self.block |
|
736 | block = block if block is not None else self.block | |
639 | return self.client.kill(targets=self._targets, block=block) |
|
737 | targets = targets if targets is not None else self.targets | |
|
738 | return self.client.kill(targets=targets, block=block) | |||
640 |
|
739 | |||
641 | #---------------------------------------- |
|
740 | #---------------------------------------- | |
642 | # activate for %px,%autopx magics |
|
741 | # activate for %px,%autopx magics | |
@@ -684,15 +783,16 b' class LoadBalancedView(View):' | |||||
684 |
|
783 | |||
685 | """ |
|
784 | """ | |
686 |
|
785 | |||
687 | _flag_names = ['block', 'track', 'follow', 'after', 'timeout'] |
|
786 | follow=Any() | |
|
787 | after=Any() | |||
|
788 | timeout=CFloat() | |||
688 |
|
789 | |||
689 | def __init__(self, client=None, socket=None, targets=None): |
|
790 | _task_scheme = Any() | |
690 | super(LoadBalancedView, self).__init__(client=client, socket=socket, targets=targets) |
|
791 | _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout']) | |
691 | self._ntargets = 1 |
|
792 | ||
|
793 | def __init__(self, client=None, socket=None, **flags): | |||
|
794 | super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags) | |||
692 | self._task_scheme=client._task_scheme |
|
795 | self._task_scheme=client._task_scheme | |
693 | if targets is None: |
|
|||
694 | self._targets = None |
|
|||
695 | self._idents=[] |
|
|||
696 |
|
796 | |||
697 | def _validate_dependency(self, dep): |
|
797 | def _validate_dependency(self, dep): | |
698 | """validate a dependency. |
|
798 | """validate a dependency. | |
@@ -786,7 +886,8 b' class LoadBalancedView(View):' | |||||
786 | @sync_results |
|
886 | @sync_results | |
787 | @save_ids |
|
887 | @save_ids | |
788 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, |
|
888 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, | |
789 |
after=None, follow=None, timeout=None |
|
889 | after=None, follow=None, timeout=None, | |
|
890 | targets=None): | |||
790 | """calls f(*args, **kwargs) on a remote engine, returning the result. |
|
891 | """calls f(*args, **kwargs) on a remote engine, returning the result. | |
791 |
|
892 | |||
792 | This method temporarily sets all of `apply`'s flags for a single call. |
|
893 | This method temporarily sets all of `apply`'s flags for a single call. | |
@@ -844,9 +945,16 b' class LoadBalancedView(View):' | |||||
844 | after = self.after if after is None else after |
|
945 | after = self.after if after is None else after | |
845 | follow = self.follow if follow is None else follow |
|
946 | follow = self.follow if follow is None else follow | |
846 | timeout = self.timeout if timeout is None else timeout |
|
947 | timeout = self.timeout if timeout is None else timeout | |
|
948 | targets = self.targets if targets is None else targets | |||
|
949 | ||||
|
950 | if targets is None: | |||
|
951 | idents = [] | |||
|
952 | else: | |||
|
953 | idents = self.client._build_targets(targets)[0] | |||
|
954 | ||||
847 | after = self._render_dependency(after) |
|
955 | after = self._render_dependency(after) | |
848 | follow = self._render_dependency(follow) |
|
956 | follow = self._render_dependency(follow) | |
849 |
subheader = dict(after=after, follow=follow, timeout=timeout, targets= |
|
957 | subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents) | |
850 |
|
958 | |||
851 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, |
|
959 | msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track, | |
852 | subheader=subheader) |
|
960 | subheader=subheader) | |
@@ -916,5 +1024,5 b' class LoadBalancedView(View):' | |||||
916 |
|
1024 | |||
917 | pf = ParallelFunction(self, f, block=block, chunksize=chunksize) |
|
1025 | pf = ParallelFunction(self, f, block=block, chunksize=chunksize) | |
918 | return pf.map(*sequences) |
|
1026 | return pf.map(*sequences) | |
919 |
|
1027 | |||
920 | __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file |
|
1028 | __all__ = ['LoadBalancedView', 'DirectView'] |
General Comments 0
You need to be logged in to leave comments.
Login now