##// END OF EJS Templates
update connections and diagrams for reduced sockets
MinRK -
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -1,1570 +1,1584 b''
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.pickleutil import Reference
27 from IPython.utils.pickleutil import Reference
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 Dict, List, Bool, Str, Set)
29 Dict, List, Bool, Str, Set)
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
32
32
33 from . import error
33 from . import error
34 from . import map as Map
34 from . import map as Map
35 from . import util
35 from . import util
36 from . import streamsession as ss
36 from . import streamsession as ss
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 from .clusterdir import ClusterDir, ClusterDirError
38 from .clusterdir import ClusterDir, ClusterDirError
39 from .dependency import Dependency, depend, require, dependent
39 from .dependency import Dependency, depend, require, dependent
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 from .util import ReverseDict, validate_url, disambiguate_url
41 from .util import ReverseDict, validate_url, disambiguate_url
42 from .view import DirectView, LoadBalancedView
42 from .view import DirectView, LoadBalancedView
43
43
44 #--------------------------------------------------------------------------
44 #--------------------------------------------------------------------------
45 # helpers for implementing old MEC API via client.apply
45 # helpers for implementing old MEC API via client.apply
46 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
47
47
48 def _push(user_ns, **ns):
48 def _push(user_ns, **ns):
49 """helper method for implementing `client.push` via `client.apply`"""
49 """helper method for implementing `client.push` via `client.apply`"""
50 user_ns.update(ns)
50 user_ns.update(ns)
51
51
52 def _pull(user_ns, keys):
52 def _pull(user_ns, keys):
53 """helper method for implementing `client.pull` via `client.apply`"""
53 """helper method for implementing `client.pull` via `client.apply`"""
54 if isinstance(keys, (list,tuple, set)):
54 if isinstance(keys, (list,tuple, set)):
55 for key in keys:
55 for key in keys:
56 if not user_ns.has_key(key):
56 if not user_ns.has_key(key):
57 raise NameError("name '%s' is not defined"%key)
57 raise NameError("name '%s' is not defined"%key)
58 return map(user_ns.get, keys)
58 return map(user_ns.get, keys)
59 else:
59 else:
60 if not user_ns.has_key(keys):
60 if not user_ns.has_key(keys):
61 raise NameError("name '%s' is not defined"%keys)
61 raise NameError("name '%s' is not defined"%keys)
62 return user_ns.get(keys)
62 return user_ns.get(keys)
63
63
64 def _clear(user_ns):
64 def _clear(user_ns):
65 """helper method for implementing `client.clear` via `client.apply`"""
65 """helper method for implementing `client.clear` via `client.apply`"""
66 user_ns.clear()
66 user_ns.clear()
67
67
68 def _execute(user_ns, code):
68 def _execute(user_ns, code):
69 """helper method for implementing `client.execute` via `client.apply`"""
69 """helper method for implementing `client.execute` via `client.apply`"""
70 exec code in user_ns
70 exec code in user_ns
71
71
72
72
73 #--------------------------------------------------------------------------
73 #--------------------------------------------------------------------------
74 # Decorators for Client methods
74 # Decorators for Client methods
75 #--------------------------------------------------------------------------
75 #--------------------------------------------------------------------------
76
76
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Decorator for Client methods that need up-to-date engine/result
    state before running.
    """
    self.spin()
    return f(self, *args, **kwargs)
82
82
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block.

    If the wrapped call passes block=None (or omits it), self.block is
    used for the duration of the call; whatever value self.block held
    before is restored afterwards, even on error.
    """
    requested = kwargs.get('block', None)
    previous = self.block
    self.block = previous if requested is None else requested
    try:
        result = f(self, *args, **kwargs)
    finally:
        self.block = previous
    return result
95
95
96
96
97 #--------------------------------------------------------------------------
97 #--------------------------------------------------------------------------
98 # Classes
98 # Classes
99 #--------------------------------------------------------------------------
99 #--------------------------------------------------------------------------
100
100
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed set of allowed keys, with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE: dict.update bypasses our __setitem__, so constructor
        # arguments are not key-checked here (same as the original).
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` instead of `key in self.iterkeys()`: iterkeys()
        # is Python-2-only and forces an O(n) scan; containment on the
        # dict itself is the direct O(1) test.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
151
151
152
152
153 class Client(HasTraits):
153 class Client(HasTraits):
154 """A semi-synchronous client to the IPython ZMQ controller
154 """A semi-synchronous client to the IPython ZMQ controller
155
155
156 Parameters
156 Parameters
157 ----------
157 ----------
158
158
159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 Connection information for the Hub's registration. If a json connector
160 Connection information for the Hub's registration. If a json connector
161 file is given, then likely no further configuration is necessary.
161 file is given, then likely no further configuration is necessary.
162 [Default: use profile]
162 [Default: use profile]
163 profile : bytes
163 profile : bytes
164 The name of the Cluster profile to be used to find connector information.
164 The name of the Cluster profile to be used to find connector information.
165 [Default: 'default']
165 [Default: 'default']
166 context : zmq.Context
166 context : zmq.Context
167 Pass an existing zmq.Context instance, otherwise the client will create its own.
167 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 username : bytes
168 username : bytes
169 set username to be passed to the Session object
169 set username to be passed to the Session object
170 debug : bool
170 debug : bool
171 flag for lots of message printing for debug purposes
171 flag for lots of message printing for debug purposes
172
172
173 #-------------- ssh related args ----------------
173 #-------------- ssh related args ----------------
174 # These are args for configuring the ssh tunnel to be used
174 # These are args for configuring the ssh tunnel to be used
175 # credentials are used to forward connections over ssh to the Controller
175 # credentials are used to forward connections over ssh to the Controller
176 # Note that the ip given in `addr` needs to be relative to sshserver
176 # Note that the ip given in `addr` needs to be relative to sshserver
177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 # and set sshserver as the same machine the Controller is on. However,
178 # and set sshserver as the same machine the Controller is on. However,
179 # the only requirement is that sshserver is able to see the Controller
179 # the only requirement is that sshserver is able to see the Controller
180 # (i.e. is within the same trusted network).
180 # (i.e. is within the same trusted network).
181
181
182 sshserver : str
182 sshserver : str
183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 If keyfile or password is specified, and this is not, it will default to
184 If keyfile or password is specified, and this is not, it will default to
185 the ip given in addr.
185 the ip given in addr.
186 sshkey : str; path to public ssh key file
186 sshkey : str; path to public ssh key file
187 This specifies a key to be used in ssh login, default None.
187 This specifies a key to be used in ssh login, default None.
188 Regular default ssh keys will be used without specifying this argument.
188 Regular default ssh keys will be used without specifying this argument.
189 password : str
189 password : str
190 Your ssh password to sshserver. Note that if this is left None,
190 Your ssh password to sshserver. Note that if this is left None,
191 you will be prompted for it if passwordless key based login is unavailable.
191 you will be prompted for it if passwordless key based login is unavailable.
192 paramiko : bool
192 paramiko : bool
193 flag for whether to use paramiko instead of shell ssh for tunneling.
193 flag for whether to use paramiko instead of shell ssh for tunneling.
194 [default: True on win32, False else]
194 [default: True on win32, False else]
195
195
196 #------- exec authentication args -------
196 #------- exec authentication args -------
197 # If even localhost is untrusted, you can have some protection against
197 # If even localhost is untrusted, you can have some protection against
198 # unauthorized execution by using a key. Messages are still sent
198 # unauthorized execution by using a key. Messages are still sent
199 # as cleartext, so if someone can snoop your loopback traffic this will
199 # as cleartext, so if someone can snoop your loopback traffic this will
200 # not help against malicious attacks.
200 # not help against malicious attacks.
201
201
202 exec_key : str
202 exec_key : str
203 an authentication key or file containing a key
203 an authentication key or file containing a key
204 default: None
204 default: None
205
205
206
206
207 Attributes
207 Attributes
208 ----------
208 ----------
209
209
210 ids : set of int engine IDs
210 ids : set of int engine IDs
211 requesting the ids attribute always synchronizes
211 requesting the ids attribute always synchronizes
212 the registration state. To request ids without synchronization,
212 the registration state. To request ids without synchronization,
213 use semi-private _ids attributes.
213 use semi-private _ids attributes.
214
214
215 history : list of msg_ids
215 history : list of msg_ids
216 a list of msg_ids, keeping track of all the execution
216 a list of msg_ids, keeping track of all the execution
217 messages you have submitted in order.
217 messages you have submitted in order.
218
218
219 outstanding : set of msg_ids
219 outstanding : set of msg_ids
220 a set of msg_ids that have been submitted, but whose
220 a set of msg_ids that have been submitted, but whose
221 results have not yet been received.
221 results have not yet been received.
222
222
223 results : dict
223 results : dict
224 a dict of all our results, keyed by msg_id
224 a dict of all our results, keyed by msg_id
225
225
226 block : bool
226 block : bool
227 determines default behavior when block not specified
227 determines default behavior when block not specified
228 in execution methods
228 in execution methods
229
229
230 Methods
230 Methods
231 -------
231 -------
232
232
233 spin
233 spin
234 flushes incoming results and registration state changes
234 flushes incoming results and registration state changes
235 control methods spin, and requesting `ids` also ensures up to date
235 control methods spin, and requesting `ids` also ensures up to date
236
236
237 barrier
237 barrier
238 wait on one or more msg_ids
238 wait on one or more msg_ids
239
239
240 execution methods
240 execution methods
241 apply
241 apply
242 legacy: execute, run
242 legacy: execute, run
243
243
244 query methods
244 query methods
245 queue_status, get_result, purge
245 queue_status, get_result, purge
246
246
247 control methods
247 control methods
248 abort, shutdown
248 abort, shutdown
249
249
250 """
250 """
251
251
252
252
253 block = Bool(False)
253 block = Bool(False)
254 outstanding = Set()
254 outstanding = Set()
255 results = Instance('collections.defaultdict', (dict,))
255 results = Instance('collections.defaultdict', (dict,))
256 metadata = Instance('collections.defaultdict', (Metadata,))
256 metadata = Instance('collections.defaultdict', (Metadata,))
257 history = List()
257 history = List()
258 debug = Bool(False)
258 debug = Bool(False)
259 profile=CUnicode('default')
259 profile=CUnicode('default')
260
260
261 _outstanding_dict = Instance('collections.defaultdict', (set,))
261 _outstanding_dict = Instance('collections.defaultdict', (set,))
262 _ids = List()
262 _ids = List()
263 _connected=Bool(False)
263 _connected=Bool(False)
264 _ssh=Bool(False)
264 _ssh=Bool(False)
265 _context = Instance('zmq.Context')
265 _context = Instance('zmq.Context')
266 _config = Dict()
266 _config = Dict()
267 _engines=Instance(ReverseDict, (), {})
267 _engines=Instance(ReverseDict, (), {})
268 # _hub_socket=Instance('zmq.Socket')
268 # _hub_socket=Instance('zmq.Socket')
269 _query_socket=Instance('zmq.Socket')
269 _query_socket=Instance('zmq.Socket')
270 _control_socket=Instance('zmq.Socket')
270 _control_socket=Instance('zmq.Socket')
271 _iopub_socket=Instance('zmq.Socket')
271 _iopub_socket=Instance('zmq.Socket')
272 _notification_socket=Instance('zmq.Socket')
272 _notification_socket=Instance('zmq.Socket')
273 _mux_socket=Instance('zmq.Socket')
273 _apply_socket=Instance('zmq.Socket')
274 _task_socket=Instance('zmq.Socket')
274 _mux_ident=Str()
275 _task_ident=Str()
275 _task_scheme=Str()
276 _task_scheme=Str()
276 _balanced_views=Dict()
277 _balanced_views=Dict()
277 _direct_views=Dict()
278 _direct_views=Dict()
278 _closed = False
279 _closed = False
279
280
280 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
281 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
281 context=None, username=None, debug=False, exec_key=None,
282 context=None, username=None, debug=False, exec_key=None,
282 sshserver=None, sshkey=None, password=None, paramiko=None,
283 sshserver=None, sshkey=None, password=None, paramiko=None,
283 ):
284 ):
284 super(Client, self).__init__(debug=debug, profile=profile)
285 super(Client, self).__init__(debug=debug, profile=profile)
285 if context is None:
286 if context is None:
286 context = zmq.Context()
287 context = zmq.Context()
287 self._context = context
288 self._context = context
288
289
289
290
290 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
291 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
291 if self._cd is not None:
292 if self._cd is not None:
292 if url_or_file is None:
293 if url_or_file is None:
293 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
294 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
294 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
295 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
295 " Please specify at least one of url_or_file or profile."
296 " Please specify at least one of url_or_file or profile."
296
297
297 try:
298 try:
298 validate_url(url_or_file)
299 validate_url(url_or_file)
299 except AssertionError:
300 except AssertionError:
300 if not os.path.exists(url_or_file):
301 if not os.path.exists(url_or_file):
301 if self._cd:
302 if self._cd:
302 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
303 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
303 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
304 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
304 with open(url_or_file) as f:
305 with open(url_or_file) as f:
305 cfg = json.loads(f.read())
306 cfg = json.loads(f.read())
306 else:
307 else:
307 cfg = {'url':url_or_file}
308 cfg = {'url':url_or_file}
308
309
309 # sync defaults from args, json:
310 # sync defaults from args, json:
310 if sshserver:
311 if sshserver:
311 cfg['ssh'] = sshserver
312 cfg['ssh'] = sshserver
312 if exec_key:
313 if exec_key:
313 cfg['exec_key'] = exec_key
314 cfg['exec_key'] = exec_key
314 exec_key = cfg['exec_key']
315 exec_key = cfg['exec_key']
315 sshserver=cfg['ssh']
316 sshserver=cfg['ssh']
316 url = cfg['url']
317 url = cfg['url']
317 location = cfg.setdefault('location', None)
318 location = cfg.setdefault('location', None)
318 cfg['url'] = disambiguate_url(cfg['url'], location)
319 cfg['url'] = disambiguate_url(cfg['url'], location)
319 url = cfg['url']
320 url = cfg['url']
320
321
321 self._config = cfg
322 self._config = cfg
322
323
323 self._ssh = bool(sshserver or sshkey or password)
324 self._ssh = bool(sshserver or sshkey or password)
324 if self._ssh and sshserver is None:
325 if self._ssh and sshserver is None:
325 # default to ssh via localhost
326 # default to ssh via localhost
326 sshserver = url.split('://')[1].split(':')[0]
327 sshserver = url.split('://')[1].split(':')[0]
327 if self._ssh and password is None:
328 if self._ssh and password is None:
328 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
329 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
329 password=False
330 password=False
330 else:
331 else:
331 password = getpass("SSH Password for %s: "%sshserver)
332 password = getpass("SSH Password for %s: "%sshserver)
332 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
333 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
333 if exec_key is not None and os.path.isfile(exec_key):
334 if exec_key is not None and os.path.isfile(exec_key):
334 arg = 'keyfile'
335 arg = 'keyfile'
335 else:
336 else:
336 arg = 'key'
337 arg = 'key'
337 key_arg = {arg:exec_key}
338 key_arg = {arg:exec_key}
338 if username is None:
339 if username is None:
339 self.session = ss.StreamSession(**key_arg)
340 self.session = ss.StreamSession(**key_arg)
340 else:
341 else:
341 self.session = ss.StreamSession(username, **key_arg)
342 self.session = ss.StreamSession(username, **key_arg)
342 self._query_socket = self._context.socket(zmq.XREQ)
343 self._query_socket = self._context.socket(zmq.XREQ)
343 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
344 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
344 if self._ssh:
345 if self._ssh:
345 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
346 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
346 else:
347 else:
347 self._query_socket.connect(url)
348 self._query_socket.connect(url)
348
349
349 self.session.debug = self.debug
350 self.session.debug = self.debug
350
351
351 self._notification_handlers = {'registration_notification' : self._register_engine,
352 self._notification_handlers = {'registration_notification' : self._register_engine,
352 'unregistration_notification' : self._unregister_engine,
353 'unregistration_notification' : self._unregister_engine,
353 }
354 }
354 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
355 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
355 'apply_reply' : self._handle_apply_reply}
356 'apply_reply' : self._handle_apply_reply}
356 self._connect(sshserver, ssh_kwargs)
357 self._connect(sshserver, ssh_kwargs)
357
358
358 def __del__(self):
359 def __del__(self):
359 """cleanup sockets, but _not_ context."""
360 """cleanup sockets, but _not_ context."""
360 self.close()
361 self.close()
361
362
362 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
363 if ipython_dir is None:
364 if ipython_dir is None:
364 ipython_dir = get_ipython_dir()
365 ipython_dir = get_ipython_dir()
365 if cluster_dir is not None:
366 if cluster_dir is not None:
366 try:
367 try:
367 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
368 return
369 return
369 except ClusterDirError:
370 except ClusterDirError:
370 pass
371 pass
371 elif profile is not None:
372 elif profile is not None:
372 try:
373 try:
373 self._cd = ClusterDir.find_cluster_dir_by_profile(
374 self._cd = ClusterDir.find_cluster_dir_by_profile(
374 ipython_dir, profile)
375 ipython_dir, profile)
375 return
376 return
376 except ClusterDirError:
377 except ClusterDirError:
377 pass
378 pass
378 self._cd = None
379 self._cd = None
379
380
380 @property
381 @property
381 def ids(self):
382 def ids(self):
382 """Always up-to-date ids property."""
383 """Always up-to-date ids property."""
383 self._flush_notifications()
384 self._flush_notifications()
384 # always copy:
385 # always copy:
385 return list(self._ids)
386 return list(self._ids)
386
387
387 def close(self):
388 def close(self):
388 if self._closed:
389 if self._closed:
389 return
390 return
390 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 for socket in map(lambda name: getattr(self, name), snames):
392 for socket in map(lambda name: getattr(self, name), snames):
392 if isinstance(socket, zmq.Socket) and not socket.closed:
393 if isinstance(socket, zmq.Socket) and not socket.closed:
393 socket.close()
394 socket.close()
394 self._closed = True
395 self._closed = True
395
396
396 def _update_engines(self, engines):
397 def _update_engines(self, engines):
397 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
398 for k,v in engines.iteritems():
399 for k,v in engines.iteritems():
399 eid = int(k)
400 eid = int(k)
400 self._engines[eid] = bytes(v) # force not unicode
401 self._engines[eid] = bytes(v) # force not unicode
401 self._ids.append(eid)
402 self._ids.append(eid)
402 self._ids = sorted(self._ids)
403 self._ids = sorted(self._ids)
403 if sorted(self._engines.keys()) != range(len(self._engines)) and \
404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
404 self._task_scheme == 'pure' and self._task_socket:
405 self._task_scheme == 'pure' and self._task_ident:
405 self._stop_scheduling_tasks()
406 self._stop_scheduling_tasks()
406
407
407 def _stop_scheduling_tasks(self):
408 def _stop_scheduling_tasks(self):
408 """Stop scheduling tasks because an engine has been unregistered
409 """Stop scheduling tasks because an engine has been unregistered
409 from a pure ZMQ scheduler.
410 from a pure ZMQ scheduler.
410 """
411 """
411
412 self._task_ident = ''
412 self._task_socket.close()
413 # self._task_socket.close()
413 self._task_socket = None
414 # self._task_socket = None
414 msg = "An engine has been unregistered, and we are using pure " +\
415 msg = "An engine has been unregistered, and we are using pure " +\
415 "ZMQ task scheduling. Task farming will be disabled."
416 "ZMQ task scheduling. Task farming will be disabled."
416 if self.outstanding:
417 if self.outstanding:
417 msg += " If you were running tasks when this happened, " +\
418 msg += " If you were running tasks when this happened, " +\
418 "some `outstanding` msg_ids may never resolve."
419 "some `outstanding` msg_ids may never resolve."
419 warnings.warn(msg, RuntimeWarning)
420 warnings.warn(msg, RuntimeWarning)
420
421
421 def _build_targets(self, targets):
422 def _build_targets(self, targets):
422 """Turn valid target IDs or 'all' into two lists:
423 """Turn valid target IDs or 'all' into two lists:
423 (int_ids, uuids).
424 (int_ids, uuids).
424 """
425 """
425 if targets is None:
426 if targets is None:
426 targets = self._ids
427 targets = self._ids
427 elif isinstance(targets, str):
428 elif isinstance(targets, str):
428 if targets.lower() == 'all':
429 if targets.lower() == 'all':
429 targets = self._ids
430 targets = self._ids
430 else:
431 else:
431 raise TypeError("%r not valid str target, must be 'all'"%(targets))
432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
432 elif isinstance(targets, int):
433 elif isinstance(targets, int):
433 targets = [targets]
434 targets = [targets]
434 return [self._engines[t] for t in targets], list(targets)
435 return [self._engines[t] for t in targets], list(targets)
435
436
436 def _connect(self, sshserver, ssh_kwargs):
437 def _connect(self, sshserver, ssh_kwargs):
437 """setup all our socket connections to the controller. This is called from
438 """setup all our socket connections to the controller. This is called from
438 __init__."""
439 __init__."""
439
440
440 # Maybe allow reconnecting?
441 # Maybe allow reconnecting?
441 if self._connected:
442 if self._connected:
442 return
443 return
443 self._connected=True
444 self._connected=True
444
445
445 def connect_socket(s, url):
446 def connect_socket(s, url):
446 url = disambiguate_url(url, self._config['location'])
447 url = disambiguate_url(url, self._config['location'])
447 if self._ssh:
448 if self._ssh:
448 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
449 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
449 else:
450 else:
450 return s.connect(url)
451 return s.connect(url)
451
452
452 self.session.send(self._query_socket, 'connection_request')
453 self.session.send(self._query_socket, 'connection_request')
453 idents,msg = self.session.recv(self._query_socket,mode=0)
454 idents,msg = self.session.recv(self._query_socket,mode=0)
454 if self.debug:
455 if self.debug:
455 pprint(msg)
456 pprint(msg)
456 msg = ss.Message(msg)
457 msg = ss.Message(msg)
457 content = msg.content
458 content = msg.content
458 self._config['registration'] = dict(content)
459 self._config['registration'] = dict(content)
459 if content.status == 'ok':
460 if content.status == 'ok':
461 self._apply_socket = self._context.socket(zmq.XREP)
462 self._apply_socket.setsockopt(zmq.IDENTITY, self.session.session)
460 if content.mux:
463 if content.mux:
461 self._mux_socket = self._context.socket(zmq.XREQ)
464 # self._mux_socket = self._context.socket(zmq.XREQ)
462 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
465 self._mux_ident = 'mux'
463 connect_socket(self._mux_socket, content.mux)
466 connect_socket(self._apply_socket, content.mux)
464 if content.task:
467 if content.task:
465 self._task_scheme, task_addr = content.task
468 self._task_scheme, task_addr = content.task
466 self._task_socket = self._context.socket(zmq.XREQ)
469 # self._task_socket = self._context.socket(zmq.XREQ)
467 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
470 # self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
468 connect_socket(self._task_socket, task_addr)
471 connect_socket(self._apply_socket, task_addr)
472 self._task_ident = 'task'
469 if content.notification:
473 if content.notification:
470 self._notification_socket = self._context.socket(zmq.SUB)
474 self._notification_socket = self._context.socket(zmq.SUB)
471 connect_socket(self._notification_socket, content.notification)
475 connect_socket(self._notification_socket, content.notification)
472 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
476 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
473 # if content.query:
477 # if content.query:
474 # self._query_socket = self._context.socket(zmq.XREQ)
478 # self._query_socket = self._context.socket(zmq.XREQ)
475 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
479 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
476 # connect_socket(self._query_socket, content.query)
480 # connect_socket(self._query_socket, content.query)
477 if content.control:
481 if content.control:
478 self._control_socket = self._context.socket(zmq.XREQ)
482 self._control_socket = self._context.socket(zmq.XREQ)
479 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
483 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
480 connect_socket(self._control_socket, content.control)
484 connect_socket(self._control_socket, content.control)
481 if content.iopub:
485 if content.iopub:
482 self._iopub_socket = self._context.socket(zmq.SUB)
486 self._iopub_socket = self._context.socket(zmq.SUB)
483 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
487 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
484 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
488 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
485 connect_socket(self._iopub_socket, content.iopub)
489 connect_socket(self._iopub_socket, content.iopub)
486 self._update_engines(dict(content.engines))
490 self._update_engines(dict(content.engines))
487
491 # give XREP apply_socket some time to connect
492 time.sleep(0.25)
488 else:
493 else:
489 self._connected = False
494 self._connected = False
490 raise Exception("Failed to connect!")
495 raise Exception("Failed to connect!")
491
496
492 #--------------------------------------------------------------------------
497 #--------------------------------------------------------------------------
493 # handlers and callbacks for incoming messages
498 # handlers and callbacks for incoming messages
494 #--------------------------------------------------------------------------
499 #--------------------------------------------------------------------------
495
500
496 def _unwrap_exception(self, content):
501 def _unwrap_exception(self, content):
497 """unwrap exception, and remap engineid to int."""
502 """unwrap exception, and remap engineid to int."""
498 e = error.unwrap_exception(content)
503 e = error.unwrap_exception(content)
499 print e.traceback
504 # print e.traceback
500 if e.engine_info:
505 if e.engine_info:
501 e_uuid = e.engine_info['engine_uuid']
506 e_uuid = e.engine_info['engine_uuid']
502 eid = self._engines[e_uuid]
507 eid = self._engines[e_uuid]
503 e.engine_info['engine_id'] = eid
508 e.engine_info['engine_id'] = eid
504 return e
509 return e
505
510
506 def _extract_metadata(self, header, parent, content):
511 def _extract_metadata(self, header, parent, content):
507 md = {'msg_id' : parent['msg_id'],
512 md = {'msg_id' : parent['msg_id'],
508 'received' : datetime.now(),
513 'received' : datetime.now(),
509 'engine_uuid' : header.get('engine', None),
514 'engine_uuid' : header.get('engine', None),
510 'follow' : parent.get('follow', []),
515 'follow' : parent.get('follow', []),
511 'after' : parent.get('after', []),
516 'after' : parent.get('after', []),
512 'status' : content['status'],
517 'status' : content['status'],
513 }
518 }
514
519
515 if md['engine_uuid'] is not None:
520 if md['engine_uuid'] is not None:
516 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
521 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
517
522
518 if 'date' in parent:
523 if 'date' in parent:
519 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
524 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
520 if 'started' in header:
525 if 'started' in header:
521 md['started'] = datetime.strptime(header['started'], util.ISO8601)
526 md['started'] = datetime.strptime(header['started'], util.ISO8601)
522 if 'date' in header:
527 if 'date' in header:
523 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
528 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
524 return md
529 return md
525
530
526 def _register_engine(self, msg):
531 def _register_engine(self, msg):
527 """Register a new engine, and update our connection info."""
532 """Register a new engine, and update our connection info."""
528 content = msg['content']
533 content = msg['content']
529 eid = content['id']
534 eid = content['id']
530 d = {eid : content['queue']}
535 d = {eid : content['queue']}
531 self._update_engines(d)
536 self._update_engines(d)
532
537
533 def _unregister_engine(self, msg):
538 def _unregister_engine(self, msg):
534 """Unregister an engine that has died."""
539 """Unregister an engine that has died."""
535 content = msg['content']
540 content = msg['content']
536 eid = int(content['id'])
541 eid = int(content['id'])
537 if eid in self._ids:
542 if eid in self._ids:
538 self._ids.remove(eid)
543 self._ids.remove(eid)
539 uuid = self._engines.pop(eid)
544 uuid = self._engines.pop(eid)
540
545
541 self._handle_stranded_msgs(eid, uuid)
546 self._handle_stranded_msgs(eid, uuid)
542
547
543 if self._task_socket and self._task_scheme == 'pure':
548 if self._task_ident and self._task_scheme == 'pure':
544 self._stop_scheduling_tasks()
549 self._stop_scheduling_tasks()
545
550
546 def _handle_stranded_msgs(self, eid, uuid):
551 def _handle_stranded_msgs(self, eid, uuid):
547 """Handle messages known to be on an engine when the engine unregisters.
552 """Handle messages known to be on an engine when the engine unregisters.
548
553
549 It is possible that this will fire prematurely - that is, an engine will
554 It is possible that this will fire prematurely - that is, an engine will
550 go down after completing a result, and the client will be notified
555 go down after completing a result, and the client will be notified
551 of the unregistration and later receive the successful result.
556 of the unregistration and later receive the successful result.
552 """
557 """
553
558
554 outstanding = self._outstanding_dict[uuid]
559 outstanding = self._outstanding_dict[uuid]
555
560
556 for msg_id in list(outstanding):
561 for msg_id in list(outstanding):
557 if msg_id in self.results:
562 if msg_id in self.results:
558 # we already
563 # we already
559 continue
564 continue
560 try:
565 try:
561 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
566 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
562 except:
567 except:
563 content = error.wrap_exception()
568 content = error.wrap_exception()
564 # build a fake message:
569 # build a fake message:
565 parent = {}
570 parent = {}
566 header = {}
571 header = {}
567 parent['msg_id'] = msg_id
572 parent['msg_id'] = msg_id
568 header['engine'] = uuid
573 header['engine'] = uuid
569 header['date'] = datetime.now().strftime(util.ISO8601)
574 header['date'] = datetime.now().strftime(util.ISO8601)
570 msg = dict(parent_header=parent, header=header, content=content)
575 msg = dict(parent_header=parent, header=header, content=content)
571 self._handle_apply_reply(msg)
576 self._handle_apply_reply(msg)
572
577
573 def _handle_execute_reply(self, msg):
578 def _handle_execute_reply(self, msg):
574 """Save the reply to an execute_request into our results.
579 """Save the reply to an execute_request into our results.
575
580
576 execute messages are never actually used. apply is used instead.
581 execute messages are never actually used. apply is used instead.
577 """
582 """
578
583
579 parent = msg['parent_header']
584 parent = msg['parent_header']
580 msg_id = parent['msg_id']
585 msg_id = parent['msg_id']
581 if msg_id not in self.outstanding:
586 if msg_id not in self.outstanding:
582 if msg_id in self.history:
587 if msg_id in self.history:
583 print ("got stale result: %s"%msg_id)
588 print ("got stale result: %s"%msg_id)
584 else:
589 else:
585 print ("got unknown result: %s"%msg_id)
590 print ("got unknown result: %s"%msg_id)
586 else:
591 else:
587 self.outstanding.remove(msg_id)
592 self.outstanding.remove(msg_id)
588 self.results[msg_id] = self._unwrap_exception(msg['content'])
593 self.results[msg_id] = self._unwrap_exception(msg['content'])
589
594
590 def _handle_apply_reply(self, msg):
595 def _handle_apply_reply(self, msg):
591 """Save the reply to an apply_request into our results."""
596 """Save the reply to an apply_request into our results."""
592 parent = msg['parent_header']
597 parent = msg['parent_header']
593 msg_id = parent['msg_id']
598 msg_id = parent['msg_id']
594 if msg_id not in self.outstanding:
599 if msg_id not in self.outstanding:
595 if msg_id in self.history:
600 if msg_id in self.history:
596 print ("got stale result: %s"%msg_id)
601 print ("got stale result: %s"%msg_id)
597 print self.results[msg_id]
602 print self.results[msg_id]
598 print msg
603 print msg
599 else:
604 else:
600 print ("got unknown result: %s"%msg_id)
605 print ("got unknown result: %s"%msg_id)
601 else:
606 else:
602 self.outstanding.remove(msg_id)
607 self.outstanding.remove(msg_id)
603 content = msg['content']
608 content = msg['content']
604 header = msg['header']
609 header = msg['header']
605
610
606 # construct metadata:
611 # construct metadata:
607 md = self.metadata[msg_id]
612 md = self.metadata[msg_id]
608 md.update(self._extract_metadata(header, parent, content))
613 md.update(self._extract_metadata(header, parent, content))
609 # is this redundant?
614 # is this redundant?
610 self.metadata[msg_id] = md
615 self.metadata[msg_id] = md
611
616
612 e_outstanding = self._outstanding_dict[md['engine_uuid']]
617 e_outstanding = self._outstanding_dict[md['engine_uuid']]
613 if msg_id in e_outstanding:
618 if msg_id in e_outstanding:
614 e_outstanding.remove(msg_id)
619 e_outstanding.remove(msg_id)
615
620
616 # construct result:
621 # construct result:
617 if content['status'] == 'ok':
622 if content['status'] == 'ok':
618 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
623 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
619 elif content['status'] == 'aborted':
624 elif content['status'] == 'aborted':
620 self.results[msg_id] = error.AbortedTask(msg_id)
625 self.results[msg_id] = error.AbortedTask(msg_id)
621 elif content['status'] == 'resubmitted':
626 elif content['status'] == 'resubmitted':
622 # TODO: handle resubmission
627 # TODO: handle resubmission
623 pass
628 pass
624 else:
629 else:
625 self.results[msg_id] = self._unwrap_exception(content)
630 self.results[msg_id] = self._unwrap_exception(content)
626
631
627 def _flush_notifications(self):
632 def _flush_notifications(self):
628 """Flush notifications of engine registrations waiting
633 """Flush notifications of engine registrations waiting
629 in ZMQ queue."""
634 in ZMQ queue."""
630 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
635 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
631 while msg is not None:
636 while msg is not None:
632 if self.debug:
637 if self.debug:
633 pprint(msg)
638 pprint(msg)
634 msg = msg[-1]
639 msg = msg[-1]
635 msg_type = msg['msg_type']
640 msg_type = msg['msg_type']
636 handler = self._notification_handlers.get(msg_type, None)
641 handler = self._notification_handlers.get(msg_type, None)
637 if handler is None:
642 if handler is None:
638 raise Exception("Unhandled message type: %s"%msg.msg_type)
643 raise Exception("Unhandled message type: %s"%msg.msg_type)
639 else:
644 else:
640 handler(msg)
645 handler(msg)
641 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
646 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
642
647
643 def _flush_results(self, sock):
648 def _flush_results(self, sock):
644 """Flush task or queue results waiting in ZMQ queue."""
649 """Flush task or queue results waiting in ZMQ queue."""
645 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
646 while msg is not None:
651 while msg is not None:
647 if self.debug:
652 if self.debug:
648 pprint(msg)
653 pprint(msg)
649 msg = msg[-1]
654 msg = msg[-1]
650 msg_type = msg['msg_type']
655 msg_type = msg['msg_type']
651 handler = self._queue_handlers.get(msg_type, None)
656 handler = self._queue_handlers.get(msg_type, None)
652 if handler is None:
657 if handler is None:
653 raise Exception("Unhandled message type: %s"%msg.msg_type)
658 raise Exception("Unhandled message type: %s"%msg.msg_type)
654 else:
659 else:
655 handler(msg)
660 handler(msg)
656 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
661 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
657
662
658 def _flush_control(self, sock):
663 def _flush_control(self, sock):
659 """Flush replies from the control channel waiting
664 """Flush replies from the control channel waiting
660 in the ZMQ queue.
665 in the ZMQ queue.
661
666
662 Currently: ignore them."""
667 Currently: ignore them."""
663 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
668 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
664 while msg is not None:
669 while msg is not None:
665 if self.debug:
670 if self.debug:
666 pprint(msg)
671 pprint(msg)
667 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
672 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
668
673
669 def _flush_iopub(self, sock):
674 def _flush_iopub(self, sock):
670 """Flush replies from the iopub channel waiting
675 """Flush replies from the iopub channel waiting
671 in the ZMQ queue.
676 in the ZMQ queue.
672 """
677 """
673 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
674 while msg is not None:
679 while msg is not None:
675 if self.debug:
680 if self.debug:
676 pprint(msg)
681 pprint(msg)
677 msg = msg[-1]
682 msg = msg[-1]
678 parent = msg['parent_header']
683 parent = msg['parent_header']
679 msg_id = parent['msg_id']
684 msg_id = parent['msg_id']
680 content = msg['content']
685 content = msg['content']
681 header = msg['header']
686 header = msg['header']
682 msg_type = msg['msg_type']
687 msg_type = msg['msg_type']
683
688
684 # init metadata:
689 # init metadata:
685 md = self.metadata[msg_id]
690 md = self.metadata[msg_id]
686
691
687 if msg_type == 'stream':
692 if msg_type == 'stream':
688 name = content['name']
693 name = content['name']
689 s = md[name] or ''
694 s = md[name] or ''
690 md[name] = s + content['data']
695 md[name] = s + content['data']
691 elif msg_type == 'pyerr':
696 elif msg_type == 'pyerr':
692 md.update({'pyerr' : self._unwrap_exception(content)})
697 md.update({'pyerr' : self._unwrap_exception(content)})
693 else:
698 else:
694 md.update({msg_type : content['data']})
699 md.update({msg_type : content['data']})
695
700
696 # reduntant?
701 # reduntant?
697 self.metadata[msg_id] = md
702 self.metadata[msg_id] = md
698
703
699 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
704 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
700
705
701 #--------------------------------------------------------------------------
706 #--------------------------------------------------------------------------
702 # len, getitem
707 # len, getitem
703 #--------------------------------------------------------------------------
708 #--------------------------------------------------------------------------
704
709
705 def __len__(self):
710 def __len__(self):
706 """len(client) returns # of engines."""
711 """len(client) returns # of engines."""
707 return len(self.ids)
712 return len(self.ids)
708
713
709 def __getitem__(self, key):
714 def __getitem__(self, key):
710 """index access returns DirectView multiplexer objects
715 """index access returns DirectView multiplexer objects
711
716
712 Must be int, slice, or list/tuple/xrange of ints"""
717 Must be int, slice, or list/tuple/xrange of ints"""
713 if not isinstance(key, (int, slice, tuple, list, xrange)):
718 if not isinstance(key, (int, slice, tuple, list, xrange)):
714 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
719 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
715 else:
720 else:
716 return self.view(key, balanced=False)
721 return self.view(key, balanced=False)
717
722
718 #--------------------------------------------------------------------------
723 #--------------------------------------------------------------------------
719 # Begin public methods
724 # Begin public methods
720 #--------------------------------------------------------------------------
725 #--------------------------------------------------------------------------
721
726
722 def spin(self):
727 def spin(self):
723 """Flush any registration notifications and execution results
728 """Flush any registration notifications and execution results
724 waiting in the ZMQ queue.
729 waiting in the ZMQ queue.
725 """
730 """
726 if self._notification_socket:
731 if self._notification_socket:
727 self._flush_notifications()
732 self._flush_notifications()
728 if self._mux_socket:
733 if self._apply_socket:
729 self._flush_results(self._mux_socket)
734 self._flush_results(self._apply_socket)
730 if self._task_socket:
731 self._flush_results(self._task_socket)
732 if self._control_socket:
735 if self._control_socket:
733 self._flush_control(self._control_socket)
736 self._flush_control(self._control_socket)
734 if self._iopub_socket:
737 if self._iopub_socket:
735 self._flush_iopub(self._iopub_socket)
738 self._flush_iopub(self._iopub_socket)
736
739
737 def barrier(self, jobs=None, timeout=-1):
740 def barrier(self, jobs=None, timeout=-1):
738 """waits on one or more `jobs`, for up to `timeout` seconds.
741 """waits on one or more `jobs`, for up to `timeout` seconds.
739
742
740 Parameters
743 Parameters
741 ----------
744 ----------
742
745
743 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
746 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
744 ints are indices to self.history
747 ints are indices to self.history
745 strs are msg_ids
748 strs are msg_ids
746 default: wait on all outstanding messages
749 default: wait on all outstanding messages
747 timeout : float
750 timeout : float
748 a time in seconds, after which to give up.
751 a time in seconds, after which to give up.
749 default is -1, which means no timeout
752 default is -1, which means no timeout
750
753
751 Returns
754 Returns
752 -------
755 -------
753
756
754 True : when all msg_ids are done
757 True : when all msg_ids are done
755 False : timeout reached, some msg_ids still outstanding
758 False : timeout reached, some msg_ids still outstanding
756 """
759 """
757 tic = time.time()
760 tic = time.time()
758 if jobs is None:
761 if jobs is None:
759 theids = self.outstanding
762 theids = self.outstanding
760 else:
763 else:
761 if isinstance(jobs, (int, str, AsyncResult)):
764 if isinstance(jobs, (int, str, AsyncResult)):
762 jobs = [jobs]
765 jobs = [jobs]
763 theids = set()
766 theids = set()
764 for job in jobs:
767 for job in jobs:
765 if isinstance(job, int):
768 if isinstance(job, int):
766 # index access
769 # index access
767 job = self.history[job]
770 job = self.history[job]
768 elif isinstance(job, AsyncResult):
771 elif isinstance(job, AsyncResult):
769 map(theids.add, job.msg_ids)
772 map(theids.add, job.msg_ids)
770 continue
773 continue
771 theids.add(job)
774 theids.add(job)
772 if not theids.intersection(self.outstanding):
775 if not theids.intersection(self.outstanding):
773 return True
776 return True
774 self.spin()
777 self.spin()
775 while theids.intersection(self.outstanding):
778 while theids.intersection(self.outstanding):
776 if timeout >= 0 and ( time.time()-tic ) > timeout:
779 if timeout >= 0 and ( time.time()-tic ) > timeout:
777 break
780 break
778 time.sleep(1e-3)
781 time.sleep(1e-3)
779 self.spin()
782 self.spin()
780 return len(theids.intersection(self.outstanding)) == 0
783 return len(theids.intersection(self.outstanding)) == 0
781
784
782 #--------------------------------------------------------------------------
785 #--------------------------------------------------------------------------
783 # Control methods
786 # Control methods
784 #--------------------------------------------------------------------------
787 #--------------------------------------------------------------------------
785
788
786 @spinfirst
789 @spinfirst
787 @defaultblock
790 @defaultblock
788 def clear(self, targets=None, block=None):
791 def clear(self, targets=None, block=None):
789 """Clear the namespace in target(s)."""
792 """Clear the namespace in target(s)."""
790 targets = self._build_targets(targets)[0]
793 targets = self._build_targets(targets)[0]
791 for t in targets:
794 for t in targets:
792 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
795 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
793 error = False
796 error = False
794 if self.block:
797 if self.block:
795 for i in range(len(targets)):
798 for i in range(len(targets)):
796 idents,msg = self.session.recv(self._control_socket,0)
799 idents,msg = self.session.recv(self._control_socket,0)
797 if self.debug:
800 if self.debug:
798 pprint(msg)
801 pprint(msg)
799 if msg['content']['status'] != 'ok':
802 if msg['content']['status'] != 'ok':
800 error = self._unwrap_exception(msg['content'])
803 error = self._unwrap_exception(msg['content'])
801 if error:
804 if error:
802 raise error
805 raise error
803
806
804
807
805 @spinfirst
808 @spinfirst
806 @defaultblock
809 @defaultblock
807 def abort(self, jobs=None, targets=None, block=None):
810 def abort(self, jobs=None, targets=None, block=None):
808 """Abort specific jobs from the execution queues of target(s).
811 """Abort specific jobs from the execution queues of target(s).
809
812
810 This is a mechanism to prevent jobs that have already been submitted
813 This is a mechanism to prevent jobs that have already been submitted
811 from executing.
814 from executing.
812
815
813 Parameters
816 Parameters
814 ----------
817 ----------
815
818
816 jobs : msg_id, list of msg_ids, or AsyncResult
819 jobs : msg_id, list of msg_ids, or AsyncResult
817 The jobs to be aborted
820 The jobs to be aborted
818
821
819
822
820 """
823 """
821 targets = self._build_targets(targets)[0]
824 targets = self._build_targets(targets)[0]
822 msg_ids = []
825 msg_ids = []
823 if isinstance(jobs, (basestring,AsyncResult)):
826 if isinstance(jobs, (basestring,AsyncResult)):
824 jobs = [jobs]
827 jobs = [jobs]
825 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
828 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
826 if bad_ids:
829 if bad_ids:
827 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
830 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
828 for j in jobs:
831 for j in jobs:
829 if isinstance(j, AsyncResult):
832 if isinstance(j, AsyncResult):
830 msg_ids.extend(j.msg_ids)
833 msg_ids.extend(j.msg_ids)
831 else:
834 else:
832 msg_ids.append(j)
835 msg_ids.append(j)
833 content = dict(msg_ids=msg_ids)
836 content = dict(msg_ids=msg_ids)
834 for t in targets:
837 for t in targets:
835 self.session.send(self._control_socket, 'abort_request',
838 self.session.send(self._control_socket, 'abort_request',
836 content=content, ident=t)
839 content=content, ident=t)
837 error = False
840 error = False
838 if self.block:
841 if self.block:
839 for i in range(len(targets)):
842 for i in range(len(targets)):
840 idents,msg = self.session.recv(self._control_socket,0)
843 idents,msg = self.session.recv(self._control_socket,0)
841 if self.debug:
844 if self.debug:
842 pprint(msg)
845 pprint(msg)
843 if msg['content']['status'] != 'ok':
846 if msg['content']['status'] != 'ok':
844 error = self._unwrap_exception(msg['content'])
847 error = self._unwrap_exception(msg['content'])
845 if error:
848 if error:
846 raise error
849 raise error
847
850
848 @spinfirst
851 @spinfirst
849 @defaultblock
852 @defaultblock
850 def shutdown(self, targets=None, restart=False, controller=False, block=None):
853 def shutdown(self, targets=None, restart=False, controller=False, block=None):
851 """Terminates one or more engine processes, optionally including the controller."""
854 """Terminates one or more engine processes, optionally including the controller."""
852 if controller:
855 if controller:
853 targets = 'all'
856 targets = 'all'
854 targets = self._build_targets(targets)[0]
857 targets = self._build_targets(targets)[0]
855 for t in targets:
858 for t in targets:
856 self.session.send(self._control_socket, 'shutdown_request',
859 self.session.send(self._control_socket, 'shutdown_request',
857 content={'restart':restart},ident=t)
860 content={'restart':restart},ident=t)
858 error = False
861 error = False
859 if block or controller:
862 if block or controller:
860 for i in range(len(targets)):
863 for i in range(len(targets)):
861 idents,msg = self.session.recv(self._control_socket,0)
864 idents,msg = self.session.recv(self._control_socket,0)
862 if self.debug:
865 if self.debug:
863 pprint(msg)
866 pprint(msg)
864 if msg['content']['status'] != 'ok':
867 if msg['content']['status'] != 'ok':
865 error = self._unwrap_exception(msg['content'])
868 error = self._unwrap_exception(msg['content'])
866
869
867 if controller:
870 if controller:
868 time.sleep(0.25)
871 time.sleep(0.25)
869 self.session.send(self._query_socket, 'shutdown_request')
872 self.session.send(self._query_socket, 'shutdown_request')
870 idents,msg = self.session.recv(self._query_socket, 0)
873 idents,msg = self.session.recv(self._query_socket, 0)
871 if self.debug:
874 if self.debug:
872 pprint(msg)
875 pprint(msg)
873 if msg['content']['status'] != 'ok':
876 if msg['content']['status'] != 'ok':
874 error = self._unwrap_exception(msg['content'])
877 error = self._unwrap_exception(msg['content'])
875
878
876 if error:
879 if error:
877 raise error
880 raise error
878
881
879 #--------------------------------------------------------------------------
882 #--------------------------------------------------------------------------
880 # Execution methods
883 # Execution methods
881 #--------------------------------------------------------------------------
884 #--------------------------------------------------------------------------
882
885
883 @defaultblock
886 @defaultblock
884 def execute(self, code, targets='all', block=None):
887 def execute(self, code, targets='all', block=None):
885 """Executes `code` on `targets` in blocking or nonblocking manner.
888 """Executes `code` on `targets` in blocking or nonblocking manner.
886
889
887 ``execute`` is always `bound` (affects engine namespace)
890 ``execute`` is always `bound` (affects engine namespace)
888
891
889 Parameters
892 Parameters
890 ----------
893 ----------
891
894
892 code : str
895 code : str
893 the code string to be executed
896 the code string to be executed
894 targets : int/str/list of ints/strs
897 targets : int/str/list of ints/strs
895 the engines on which to execute
898 the engines on which to execute
896 default : all
899 default : all
897 block : bool
900 block : bool
898 whether or not to wait until done to return
901 whether or not to wait until done to return
899 default: self.block
902 default: self.block
900 """
903 """
901 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
904 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
902 if not block:
905 if not block:
903 return result
906 return result
904
907
905 def run(self, filename, targets='all', block=None):
908 def run(self, filename, targets='all', block=None):
906 """Execute contents of `filename` on engine(s).
909 """Execute contents of `filename` on engine(s).
907
910
908 This simply reads the contents of the file and calls `execute`.
911 This simply reads the contents of the file and calls `execute`.
909
912
910 Parameters
913 Parameters
911 ----------
914 ----------
912
915
913 filename : str
916 filename : str
914 The path to the file
917 The path to the file
915 targets : int/str/list of ints/strs
918 targets : int/str/list of ints/strs
916 the engines on which to execute
919 the engines on which to execute
917 default : all
920 default : all
918 block : bool
921 block : bool
919 whether or not to wait until done
922 whether or not to wait until done
920 default: self.block
923 default: self.block
921
924
922 """
925 """
923 with open(filename, 'r') as f:
926 with open(filename, 'r') as f:
924 # add newline in case of trailing indented whitespace
927 # add newline in case of trailing indented whitespace
925 # which will cause SyntaxError
928 # which will cause SyntaxError
926 code = f.read()+'\n'
929 code = f.read()+'\n'
927 return self.execute(code, targets=targets, block=block)
930 return self.execute(code, targets=targets, block=block)
928
931
929 def _maybe_raise(self, result):
932 def _maybe_raise(self, result):
930 """wrapper for maybe raising an exception if apply failed."""
933 """wrapper for maybe raising an exception if apply failed."""
931 if isinstance(result, error.RemoteError):
934 if isinstance(result, error.RemoteError):
932 raise result
935 raise result
933
936
934 return result
937 return result
935
938
936 def _build_dependency(self, dep):
939 def _build_dependency(self, dep):
937 """helper for building jsonable dependencies from various input forms"""
940 """helper for building jsonable dependencies from various input forms"""
938 if isinstance(dep, Dependency):
941 if isinstance(dep, Dependency):
939 return dep.as_dict()
942 return dep.as_dict()
940 elif isinstance(dep, AsyncResult):
943 elif isinstance(dep, AsyncResult):
941 return dep.msg_ids
944 return dep.msg_ids
942 elif dep is None:
945 elif dep is None:
943 return []
946 return []
944 else:
947 else:
945 # pass to Dependency constructor
948 # pass to Dependency constructor
946 return list(Dependency(dep))
949 return list(Dependency(dep))
947
950
948 @defaultblock
951 @defaultblock
949 def apply(self, f, args=None, kwargs=None, bound=False, block=None,
952 def apply(self, f, args=None, kwargs=None, bound=False, block=None,
950 targets=None, balanced=None,
953 targets=None, balanced=None,
951 after=None, follow=None, timeout=None,
954 after=None, follow=None, timeout=None,
952 track=False):
955 track=False):
953 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
956 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
954
957
955 This is the central execution command for the client.
958 This is the central execution command for the client.
956
959
957 Parameters
960 Parameters
958 ----------
961 ----------
959
962
960 f : function
963 f : function
961 The fuction to be called remotely
964 The fuction to be called remotely
962 args : tuple/list
965 args : tuple/list
963 The positional arguments passed to `f`
966 The positional arguments passed to `f`
964 kwargs : dict
967 kwargs : dict
965 The keyword arguments passed to `f`
968 The keyword arguments passed to `f`
966 bound : bool (default: False)
969 bound : bool (default: False)
967 Whether to pass the Engine(s) Namespace as the first argument to `f`.
970 Whether to pass the Engine(s) Namespace as the first argument to `f`.
968 block : bool (default: self.block)
971 block : bool (default: self.block)
969 Whether to wait for the result, or return immediately.
972 Whether to wait for the result, or return immediately.
970 False:
973 False:
971 returns AsyncResult
974 returns AsyncResult
972 True:
975 True:
973 returns actual result(s) of f(*args, **kwargs)
976 returns actual result(s) of f(*args, **kwargs)
974 if multiple targets:
977 if multiple targets:
975 list of results, matching `targets`
978 list of results, matching `targets`
976 targets : int,list of ints, 'all', None
979 targets : int,list of ints, 'all', None
977 Specify the destination of the job.
980 Specify the destination of the job.
978 if None:
981 if None:
979 Submit via Task queue for load-balancing.
982 Submit via Task queue for load-balancing.
980 if 'all':
983 if 'all':
981 Run on all active engines
984 Run on all active engines
982 if list:
985 if list:
983 Run on each specified engine
986 Run on each specified engine
984 if int:
987 if int:
985 Run on single engine
988 Run on single engine
986
989
987 balanced : bool, default None
990 balanced : bool, default None
988 whether to load-balance. This will default to True
991 whether to load-balance. This will default to True
989 if targets is unspecified, or False if targets is specified.
992 if targets is unspecified, or False if targets is specified.
990
993
991 The following arguments are only used when balanced is True:
994 The following arguments are only used when balanced is True:
992 after : Dependency or collection of msg_ids
995 after : Dependency or collection of msg_ids
993 Only for load-balanced execution (targets=None)
996 Only for load-balanced execution (targets=None)
994 Specify a list of msg_ids as a time-based dependency.
997 Specify a list of msg_ids as a time-based dependency.
995 This job will only be run *after* the dependencies
998 This job will only be run *after* the dependencies
996 have been met.
999 have been met.
997
1000
998 follow : Dependency or collection of msg_ids
1001 follow : Dependency or collection of msg_ids
999 Only for load-balanced execution (targets=None)
1002 Only for load-balanced execution (targets=None)
1000 Specify a list of msg_ids as a location-based dependency.
1003 Specify a list of msg_ids as a location-based dependency.
1001 This job will only be run on an engine where this dependency
1004 This job will only be run on an engine where this dependency
1002 is met.
1005 is met.
1003
1006
1004 timeout : float/int or None
1007 timeout : float/int or None
1005 Only for load-balanced execution (targets=None)
1008 Only for load-balanced execution (targets=None)
1006 Specify an amount of time (in seconds) for the scheduler to
1009 Specify an amount of time (in seconds) for the scheduler to
1007 wait for dependencies to be met before failing with a
1010 wait for dependencies to be met before failing with a
1008 DependencyTimeout.
1011 DependencyTimeout.
1009 track : bool
1012 track : bool
1010 whether to track non-copying sends.
1013 whether to track non-copying sends.
1011 [default False]
1014 [default False]
1012
1015
1013 after,follow,timeout only used if `balanced=True`.
1016 after,follow,timeout only used if `balanced=True`.
1014
1017
1015 Returns
1018 Returns
1016 -------
1019 -------
1017
1020
1018 if block is False:
1021 if block is False:
1019 return AsyncResult wrapping msg_ids
1022 return AsyncResult wrapping msg_ids
1020 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1023 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1021 else:
1024 else:
1022 if single target:
1025 if single target:
1023 return result of `f(*args, **kwargs)`
1026 return result of `f(*args, **kwargs)`
1024 else:
1027 else:
1025 return list of results, matching `targets`
1028 return list of results, matching `targets`
1026 """
1029 """
1027 assert not self._closed, "cannot use me anymore, I'm closed!"
1030 assert not self._closed, "cannot use me anymore, I'm closed!"
1028 # defaults:
1031 # defaults:
1029 block = block if block is not None else self.block
1032 block = block if block is not None else self.block
1030 args = args if args is not None else []
1033 args = args if args is not None else []
1031 kwargs = kwargs if kwargs is not None else {}
1034 kwargs = kwargs if kwargs is not None else {}
1032
1035
1036 if not self._ids:
1037 # flush notification socket if no engines yet
1038 any_ids = self.ids
1039 if not any_ids:
1040 raise error.NoEnginesRegistered("Can't execute without any connected engines.")
1041
1033 if balanced is None:
1042 if balanced is None:
1034 if targets is None:
1043 if targets is None:
1035 # default to balanced if targets unspecified
1044 # default to balanced if targets unspecified
1036 balanced = True
1045 balanced = True
1037 else:
1046 else:
1038 # otherwise default to multiplexing
1047 # otherwise default to multiplexing
1039 balanced = False
1048 balanced = False
1040
1049
1041 if targets is None and balanced is False:
1050 if targets is None and balanced is False:
1042 # default to all if *not* balanced, and targets is unspecified
1051 # default to all if *not* balanced, and targets is unspecified
1043 targets = 'all'
1052 targets = 'all'
1044
1053
1045 # enforce types of f,args,kwrags
1054 # enforce types of f,args,kwrags
1046 if not callable(f):
1055 if not callable(f):
1047 raise TypeError("f must be callable, not %s"%type(f))
1056 raise TypeError("f must be callable, not %s"%type(f))
1048 if not isinstance(args, (tuple, list)):
1057 if not isinstance(args, (tuple, list)):
1049 raise TypeError("args must be tuple or list, not %s"%type(args))
1058 raise TypeError("args must be tuple or list, not %s"%type(args))
1050 if not isinstance(kwargs, dict):
1059 if not isinstance(kwargs, dict):
1051 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1060 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1052
1061
1053 options = dict(bound=bound, block=block, targets=targets, track=track)
1062 options = dict(bound=bound, block=block, targets=targets, track=track)
1054
1063
1055 if balanced:
1064 if balanced:
1056 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1065 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1057 after=after, follow=follow, **options)
1066 after=after, follow=follow, **options)
1058 elif follow or after or timeout:
1067 elif follow or after or timeout:
1059 msg = "follow, after, and timeout args are only used for"
1068 msg = "follow, after, and timeout args are only used for"
1060 msg += " load-balanced execution."
1069 msg += " load-balanced execution."
1061 raise ValueError(msg)
1070 raise ValueError(msg)
1062 else:
1071 else:
1063 return self._apply_direct(f, args, kwargs, **options)
1072 return self._apply_direct(f, args, kwargs, **options)
1064
1073
1065 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1074 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1066 after=None, follow=None, timeout=None, track=None):
1075 after=None, follow=None, timeout=None, track=None):
1067 """call f(*args, **kwargs) remotely in a load-balanced manner.
1076 """call f(*args, **kwargs) remotely in a load-balanced manner.
1068
1077
1069 This is a private method, see `apply` for details.
1078 This is a private method, see `apply` for details.
1070 Not to be called directly!
1079 Not to be called directly!
1071 """
1080 """
1072
1081
1073 loc = locals()
1082 loc = locals()
1074 for name in ('bound', 'block', 'track'):
1083 for name in ('bound', 'block', 'track'):
1075 assert loc[name] is not None, "kwarg %r must be specified!"%name
1084 assert loc[name] is not None, "kwarg %r must be specified!"%name
1076
1085
1077 if self._task_socket is None:
1086 if not self._task_ident:
1078 msg = "Task farming is disabled"
1087 msg = "Task farming is disabled"
1079 if self._task_scheme == 'pure':
1088 if self._task_scheme == 'pure':
1080 msg += " because the pure ZMQ scheduler cannot handle"
1089 msg += " because the pure ZMQ scheduler cannot handle"
1081 msg += " disappearing engines."
1090 msg += " disappearing engines."
1082 raise RuntimeError(msg)
1091 raise RuntimeError(msg)
1083
1092
1084 if self._task_scheme == 'pure':
1093 if self._task_scheme == 'pure':
1085 # pure zmq scheme doesn't support dependencies
1094 # pure zmq scheme doesn't support dependencies
1086 msg = "Pure ZMQ scheduler doesn't support dependencies"
1095 msg = "Pure ZMQ scheduler doesn't support dependencies"
1087 if (follow or after):
1096 if (follow or after):
1088 # hard fail on DAG dependencies
1097 # hard fail on DAG dependencies
1089 raise RuntimeError(msg)
1098 raise RuntimeError(msg)
1090 if isinstance(f, dependent):
1099 if isinstance(f, dependent):
1091 # soft warn on functional dependencies
1100 # soft warn on functional dependencies
1092 warnings.warn(msg, RuntimeWarning)
1101 warnings.warn(msg, RuntimeWarning)
1093
1102
1094 # defaults:
1103 # defaults:
1095 args = args if args is not None else []
1104 args = args if args is not None else []
1096 kwargs = kwargs if kwargs is not None else {}
1105 kwargs = kwargs if kwargs is not None else {}
1097
1106
1098 if targets:
1107 if targets:
1099 idents,_ = self._build_targets(targets)
1108 idents,_ = self._build_targets(targets)
1100 else:
1109 else:
1101 idents = []
1110 idents = []
1102
1111
1103 after = self._build_dependency(after)
1112 after = self._build_dependency(after)
1104 follow = self._build_dependency(follow)
1113 follow = self._build_dependency(follow)
1105 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1114 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1106 bufs = util.pack_apply_message(f,args,kwargs)
1115 bufs = util.pack_apply_message(f,args,kwargs)
1107 content = dict(bound=bound)
1116 content = dict(bound=bound)
1108
1117
1109 msg = self.session.send(self._task_socket, "apply_request",
1118 msg = self.session.send(self._apply_socket, "apply_request", ident=self._task_ident,
1110 content=content, buffers=bufs, subheader=subheader, track=track)
1119 content=content, buffers=bufs, subheader=subheader, track=track)
1111 msg_id = msg['msg_id']
1120 msg_id = msg['msg_id']
1112 self.outstanding.add(msg_id)
1121 self.outstanding.add(msg_id)
1113 self.history.append(msg_id)
1122 self.history.append(msg_id)
1114 self.metadata[msg_id]['submitted'] = datetime.now()
1123 self.metadata[msg_id]['submitted'] = datetime.now()
1115 tracker = None if track is False else msg['tracker']
1124 tracker = None if track is False else msg['tracker']
1116 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1125 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1117 if block:
1126 if block:
1118 try:
1127 try:
1119 return ar.get()
1128 return ar.get()
1120 except KeyboardInterrupt:
1129 except KeyboardInterrupt:
1121 return ar
1130 return ar
1122 else:
1131 else:
1123 return ar
1132 return ar
1124
1133
1125 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1134 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1126 track=None):
1135 track=None):
1127 """Then underlying method for applying functions to specific engines
1136 """Then underlying method for applying functions to specific engines
1128 via the MUX queue.
1137 via the MUX queue.
1129
1138
1130 This is a private method, see `apply` for details.
1139 This is a private method, see `apply` for details.
1131 Not to be called directly!
1140 Not to be called directly!
1132 """
1141 """
1142
1143 if not self._mux_ident:
1144 msg = "Multiplexing is disabled"
1145 raise RuntimeError(msg)
1146
1133 loc = locals()
1147 loc = locals()
1134 for name in ('bound', 'block', 'targets', 'track'):
1148 for name in ('bound', 'block', 'targets', 'track'):
1135 assert loc[name] is not None, "kwarg %r must be specified!"%name
1149 assert loc[name] is not None, "kwarg %r must be specified!"%name
1136
1150
1137 idents,targets = self._build_targets(targets)
1151 idents,targets = self._build_targets(targets)
1138
1152
1139 subheader = {}
1153 subheader = {}
1140 content = dict(bound=bound)
1154 content = dict(bound=bound)
1141 bufs = util.pack_apply_message(f,args,kwargs)
1155 bufs = util.pack_apply_message(f,args,kwargs)
1142
1156
1143 msg_ids = []
1157 msg_ids = []
1144 trackers = []
1158 trackers = []
1145 for ident in idents:
1159 for ident in idents:
1146 msg = self.session.send(self._mux_socket, "apply_request",
1160 msg = self.session.send(self._apply_socket, "apply_request",
1147 content=content, buffers=bufs, ident=ident, subheader=subheader,
1161 content=content, buffers=bufs, ident=[self._mux_ident, ident], subheader=subheader,
1148 track=track)
1162 track=track)
1149 if track:
1163 if track:
1150 trackers.append(msg['tracker'])
1164 trackers.append(msg['tracker'])
1151 msg_id = msg['msg_id']
1165 msg_id = msg['msg_id']
1152 self.outstanding.add(msg_id)
1166 self.outstanding.add(msg_id)
1153 self._outstanding_dict[ident].add(msg_id)
1167 self._outstanding_dict[ident].add(msg_id)
1154 self.history.append(msg_id)
1168 self.history.append(msg_id)
1155 msg_ids.append(msg_id)
1169 msg_ids.append(msg_id)
1156
1170
1157 tracker = None if track is False else zmq.MessageTracker(*trackers)
1171 tracker = None if track is False else zmq.MessageTracker(*trackers)
1158 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1172 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1159
1173
1160 if block:
1174 if block:
1161 try:
1175 try:
1162 return ar.get()
1176 return ar.get()
1163 except KeyboardInterrupt:
1177 except KeyboardInterrupt:
1164 return ar
1178 return ar
1165 else:
1179 else:
1166 return ar
1180 return ar
1167
1181
1168 #--------------------------------------------------------------------------
1182 #--------------------------------------------------------------------------
1169 # construct a View object
1183 # construct a View object
1170 #--------------------------------------------------------------------------
1184 #--------------------------------------------------------------------------
1171
1185
1172 @defaultblock
1186 @defaultblock
1173 def remote(self, bound=False, block=None, targets=None, balanced=None):
1187 def remote(self, bound=False, block=None, targets=None, balanced=None):
1174 """Decorator for making a RemoteFunction"""
1188 """Decorator for making a RemoteFunction"""
1175 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1189 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1176
1190
1177 @defaultblock
1191 @defaultblock
1178 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1192 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1179 """Decorator for making a ParallelFunction"""
1193 """Decorator for making a ParallelFunction"""
1180 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1194 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1181
1195
1182 def _cache_view(self, targets, balanced):
1196 def _cache_view(self, targets, balanced):
1183 """save views, so subsequent requests don't create new objects."""
1197 """save views, so subsequent requests don't create new objects."""
1184 if balanced:
1198 if balanced:
1185 view_class = LoadBalancedView
1199 view_class = LoadBalancedView
1186 view_cache = self._balanced_views
1200 view_cache = self._balanced_views
1187 else:
1201 else:
1188 view_class = DirectView
1202 view_class = DirectView
1189 view_cache = self._direct_views
1203 view_cache = self._direct_views
1190
1204
1191 # use str, since often targets will be a list
1205 # use str, since often targets will be a list
1192 key = str(targets)
1206 key = str(targets)
1193 if key not in view_cache:
1207 if key not in view_cache:
1194 view_cache[key] = view_class(client=self, targets=targets)
1208 view_cache[key] = view_class(client=self, targets=targets)
1195
1209
1196 return view_cache[key]
1210 return view_cache[key]
1197
1211
1198 def view(self, targets=None, balanced=None):
1212 def view(self, targets=None, balanced=None):
1199 """Method for constructing View objects.
1213 """Method for constructing View objects.
1200
1214
1201 If no arguments are specified, create a LoadBalancedView
1215 If no arguments are specified, create a LoadBalancedView
1202 using all engines. If only `targets` specified, it will
1216 using all engines. If only `targets` specified, it will
1203 be a DirectView. This method is the underlying implementation
1217 be a DirectView. This method is the underlying implementation
1204 of ``client.__getitem__``.
1218 of ``client.__getitem__``.
1205
1219
1206 Parameters
1220 Parameters
1207 ----------
1221 ----------
1208
1222
1209 targets: list,slice,int,etc. [default: use all engines]
1223 targets: list,slice,int,etc. [default: use all engines]
1210 The engines to use for the View
1224 The engines to use for the View
1211 balanced : bool [default: False if targets specified, True else]
1225 balanced : bool [default: False if targets specified, True else]
1212 whether to build a LoadBalancedView or a DirectView
1226 whether to build a LoadBalancedView or a DirectView
1213
1227
1214 """
1228 """
1215
1229
1216 balanced = (targets is None) if balanced is None else balanced
1230 balanced = (targets is None) if balanced is None else balanced
1217
1231
1218 if targets is None:
1232 if targets is None:
1219 if balanced:
1233 if balanced:
1220 return self._cache_view(None,True)
1234 return self._cache_view(None,True)
1221 else:
1235 else:
1222 targets = slice(None)
1236 targets = slice(None)
1223
1237
1224 if isinstance(targets, int):
1238 if isinstance(targets, int):
1225 if targets < 0:
1239 if targets < 0:
1226 targets = self.ids[targets]
1240 targets = self.ids[targets]
1227 if targets not in self.ids:
1241 if targets not in self.ids:
1228 raise IndexError("No such engine: %i"%targets)
1242 raise IndexError("No such engine: %i"%targets)
1229 return self._cache_view(targets, balanced)
1243 return self._cache_view(targets, balanced)
1230
1244
1231 if isinstance(targets, slice):
1245 if isinstance(targets, slice):
1232 indices = range(len(self.ids))[targets]
1246 indices = range(len(self.ids))[targets]
1233 ids = sorted(self._ids)
1247 ids = sorted(self._ids)
1234 targets = [ ids[i] for i in indices ]
1248 targets = [ ids[i] for i in indices ]
1235
1249
1236 if isinstance(targets, (tuple, list, xrange)):
1250 if isinstance(targets, (tuple, list, xrange)):
1237 _,targets = self._build_targets(list(targets))
1251 _,targets = self._build_targets(list(targets))
1238 return self._cache_view(targets, balanced)
1252 return self._cache_view(targets, balanced)
1239 else:
1253 else:
1240 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1254 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1241
1255
1242 #--------------------------------------------------------------------------
1256 #--------------------------------------------------------------------------
1243 # Data movement
1257 # Data movement
1244 #--------------------------------------------------------------------------
1258 #--------------------------------------------------------------------------
1245
1259
1246 @defaultblock
1260 @defaultblock
1247 def push(self, ns, targets='all', block=None, track=False):
1261 def push(self, ns, targets='all', block=None, track=False):
1248 """Push the contents of `ns` into the namespace on `target`"""
1262 """Push the contents of `ns` into the namespace on `target`"""
1249 if not isinstance(ns, dict):
1263 if not isinstance(ns, dict):
1250 raise TypeError("Must be a dict, not %s"%type(ns))
1264 raise TypeError("Must be a dict, not %s"%type(ns))
1251 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1265 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1252 if not block:
1266 if not block:
1253 return result
1267 return result
1254
1268
1255 @defaultblock
1269 @defaultblock
1256 def pull(self, keys, targets='all', block=None):
1270 def pull(self, keys, targets='all', block=None):
1257 """Pull objects from `target`'s namespace by `keys`"""
1271 """Pull objects from `target`'s namespace by `keys`"""
1258 if isinstance(keys, basestring):
1272 if isinstance(keys, basestring):
1259 pass
1273 pass
1260 elif isinstance(keys, (list,tuple,set)):
1274 elif isinstance(keys, (list,tuple,set)):
1261 for key in keys:
1275 for key in keys:
1262 if not isinstance(key, basestring):
1276 if not isinstance(key, basestring):
1263 raise TypeError("keys must be str, not type %r"%type(key))
1277 raise TypeError("keys must be str, not type %r"%type(key))
1264 else:
1278 else:
1265 raise TypeError("keys must be strs, not %r"%keys)
1279 raise TypeError("keys must be strs, not %r"%keys)
1266 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1280 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1267 return result
1281 return result
1268
1282
1269 @defaultblock
1283 @defaultblock
1270 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1284 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1271 """
1285 """
1272 Partition a Python sequence and send the partitions to a set of engines.
1286 Partition a Python sequence and send the partitions to a set of engines.
1273 """
1287 """
1274 targets = self._build_targets(targets)[-1]
1288 targets = self._build_targets(targets)[-1]
1275 mapObject = Map.dists[dist]()
1289 mapObject = Map.dists[dist]()
1276 nparts = len(targets)
1290 nparts = len(targets)
1277 msg_ids = []
1291 msg_ids = []
1278 trackers = []
1292 trackers = []
1279 for index, engineid in enumerate(targets):
1293 for index, engineid in enumerate(targets):
1280 partition = mapObject.getPartition(seq, index, nparts)
1294 partition = mapObject.getPartition(seq, index, nparts)
1281 if flatten and len(partition) == 1:
1295 if flatten and len(partition) == 1:
1282 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1296 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1283 else:
1297 else:
1284 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1298 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1285 msg_ids.extend(r.msg_ids)
1299 msg_ids.extend(r.msg_ids)
1286 if track:
1300 if track:
1287 trackers.append(r._tracker)
1301 trackers.append(r._tracker)
1288
1302
1289 if track:
1303 if track:
1290 tracker = zmq.MessageTracker(*trackers)
1304 tracker = zmq.MessageTracker(*trackers)
1291 else:
1305 else:
1292 tracker = None
1306 tracker = None
1293
1307
1294 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1308 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1295 if block:
1309 if block:
1296 r.wait()
1310 r.wait()
1297 else:
1311 else:
1298 return r
1312 return r
1299
1313
1300 @defaultblock
1314 @defaultblock
1301 def gather(self, key, dist='b', targets='all', block=None):
1315 def gather(self, key, dist='b', targets='all', block=None):
1302 """
1316 """
1303 Gather a partitioned sequence on a set of engines as a single local seq.
1317 Gather a partitioned sequence on a set of engines as a single local seq.
1304 """
1318 """
1305
1319
1306 targets = self._build_targets(targets)[-1]
1320 targets = self._build_targets(targets)[-1]
1307 mapObject = Map.dists[dist]()
1321 mapObject = Map.dists[dist]()
1308 msg_ids = []
1322 msg_ids = []
1309 for index, engineid in enumerate(targets):
1323 for index, engineid in enumerate(targets):
1310 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1324 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1311
1325
1312 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1326 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1313 if block:
1327 if block:
1314 return r.get()
1328 return r.get()
1315 else:
1329 else:
1316 return r
1330 return r
1317
1331
1318 #--------------------------------------------------------------------------
1332 #--------------------------------------------------------------------------
1319 # Query methods
1333 # Query methods
1320 #--------------------------------------------------------------------------
1334 #--------------------------------------------------------------------------
1321
1335
1322 @spinfirst
1336 @spinfirst
1323 @defaultblock
1337 @defaultblock
1324 def get_result(self, indices_or_msg_ids=None, block=None):
1338 def get_result(self, indices_or_msg_ids=None, block=None):
1325 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1339 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1326
1340
1327 If the client already has the results, no request to the Hub will be made.
1341 If the client already has the results, no request to the Hub will be made.
1328
1342
1329 This is a convenient way to construct AsyncResult objects, which are wrappers
1343 This is a convenient way to construct AsyncResult objects, which are wrappers
1330 that include metadata about execution, and allow for awaiting results that
1344 that include metadata about execution, and allow for awaiting results that
1331 were not submitted by this Client.
1345 were not submitted by this Client.
1332
1346
1333 It can also be a convenient way to retrieve the metadata associated with
1347 It can also be a convenient way to retrieve the metadata associated with
1334 blocking execution, since it always retrieves
1348 blocking execution, since it always retrieves
1335
1349
1336 Examples
1350 Examples
1337 --------
1351 --------
1338 ::
1352 ::
1339
1353
1340 In [10]: r = client.apply()
1354 In [10]: r = client.apply()
1341
1355
1342 Parameters
1356 Parameters
1343 ----------
1357 ----------
1344
1358
1345 indices_or_msg_ids : integer history index, str msg_id, or list of either
1359 indices_or_msg_ids : integer history index, str msg_id, or list of either
1346 The indices or msg_ids of indices to be retrieved
1360 The indices or msg_ids of indices to be retrieved
1347
1361
1348 block : bool
1362 block : bool
1349 Whether to wait for the result to be done
1363 Whether to wait for the result to be done
1350
1364
1351 Returns
1365 Returns
1352 -------
1366 -------
1353
1367
1354 AsyncResult
1368 AsyncResult
1355 A single AsyncResult object will always be returned.
1369 A single AsyncResult object will always be returned.
1356
1370
1357 AsyncHubResult
1371 AsyncHubResult
1358 A subclass of AsyncResult that retrieves results from the Hub
1372 A subclass of AsyncResult that retrieves results from the Hub
1359
1373
1360 """
1374 """
1361 if indices_or_msg_ids is None:
1375 if indices_or_msg_ids is None:
1362 indices_or_msg_ids = -1
1376 indices_or_msg_ids = -1
1363
1377
1364 if not isinstance(indices_or_msg_ids, (list,tuple)):
1378 if not isinstance(indices_or_msg_ids, (list,tuple)):
1365 indices_or_msg_ids = [indices_or_msg_ids]
1379 indices_or_msg_ids = [indices_or_msg_ids]
1366
1380
1367 theids = []
1381 theids = []
1368 for id in indices_or_msg_ids:
1382 for id in indices_or_msg_ids:
1369 if isinstance(id, int):
1383 if isinstance(id, int):
1370 id = self.history[id]
1384 id = self.history[id]
1371 if not isinstance(id, str):
1385 if not isinstance(id, str):
1372 raise TypeError("indices must be str or int, not %r"%id)
1386 raise TypeError("indices must be str or int, not %r"%id)
1373 theids.append(id)
1387 theids.append(id)
1374
1388
1375 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1389 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1376 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1390 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1377
1391
1378 if remote_ids:
1392 if remote_ids:
1379 ar = AsyncHubResult(self, msg_ids=theids)
1393 ar = AsyncHubResult(self, msg_ids=theids)
1380 else:
1394 else:
1381 ar = AsyncResult(self, msg_ids=theids)
1395 ar = AsyncResult(self, msg_ids=theids)
1382
1396
1383 if block:
1397 if block:
1384 ar.wait()
1398 ar.wait()
1385
1399
1386 return ar
1400 return ar
1387
1401
1388 @spinfirst
1402 @spinfirst
1389 def result_status(self, msg_ids, status_only=True):
1403 def result_status(self, msg_ids, status_only=True):
1390 """Check on the status of the result(s) of the apply request with `msg_ids`.
1404 """Check on the status of the result(s) of the apply request with `msg_ids`.
1391
1405
1392 If status_only is False, then the actual results will be retrieved, else
1406 If status_only is False, then the actual results will be retrieved, else
1393 only the status of the results will be checked.
1407 only the status of the results will be checked.
1394
1408
1395 Parameters
1409 Parameters
1396 ----------
1410 ----------
1397
1411
1398 msg_ids : list of msg_ids
1412 msg_ids : list of msg_ids
1399 if int:
1413 if int:
1400 Passed as index to self.history for convenience.
1414 Passed as index to self.history for convenience.
1401 status_only : bool (default: True)
1415 status_only : bool (default: True)
1402 if False:
1416 if False:
1403 Retrieve the actual results of completed tasks.
1417 Retrieve the actual results of completed tasks.
1404
1418
1405 Returns
1419 Returns
1406 -------
1420 -------
1407
1421
1408 results : dict
1422 results : dict
1409 There will always be the keys 'pending' and 'completed', which will
1423 There will always be the keys 'pending' and 'completed', which will
1410 be lists of msg_ids that are incomplete or complete. If `status_only`
1424 be lists of msg_ids that are incomplete or complete. If `status_only`
1411 is False, then completed results will be keyed by their `msg_id`.
1425 is False, then completed results will be keyed by their `msg_id`.
1412 """
1426 """
1413 if not isinstance(msg_ids, (list,tuple)):
1427 if not isinstance(msg_ids, (list,tuple)):
1414 msg_ids = [msg_ids]
1428 msg_ids = [msg_ids]
1415
1429
1416 theids = []
1430 theids = []
1417 for msg_id in msg_ids:
1431 for msg_id in msg_ids:
1418 if isinstance(msg_id, int):
1432 if isinstance(msg_id, int):
1419 msg_id = self.history[msg_id]
1433 msg_id = self.history[msg_id]
1420 if not isinstance(msg_id, basestring):
1434 if not isinstance(msg_id, basestring):
1421 raise TypeError("msg_ids must be str, not %r"%msg_id)
1435 raise TypeError("msg_ids must be str, not %r"%msg_id)
1422 theids.append(msg_id)
1436 theids.append(msg_id)
1423
1437
1424 completed = []
1438 completed = []
1425 local_results = {}
1439 local_results = {}
1426
1440
1427 # comment this block out to temporarily disable local shortcut:
1441 # comment this block out to temporarily disable local shortcut:
1428 for msg_id in theids:
1442 for msg_id in theids:
1429 if msg_id in self.results:
1443 if msg_id in self.results:
1430 completed.append(msg_id)
1444 completed.append(msg_id)
1431 local_results[msg_id] = self.results[msg_id]
1445 local_results[msg_id] = self.results[msg_id]
1432 theids.remove(msg_id)
1446 theids.remove(msg_id)
1433
1447
1434 if theids: # some not locally cached
1448 if theids: # some not locally cached
1435 content = dict(msg_ids=theids, status_only=status_only)
1449 content = dict(msg_ids=theids, status_only=status_only)
1436 msg = self.session.send(self._query_socket, "result_request", content=content)
1450 msg = self.session.send(self._query_socket, "result_request", content=content)
1437 zmq.select([self._query_socket], [], [])
1451 zmq.select([self._query_socket], [], [])
1438 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1452 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1439 if self.debug:
1453 if self.debug:
1440 pprint(msg)
1454 pprint(msg)
1441 content = msg['content']
1455 content = msg['content']
1442 if content['status'] != 'ok':
1456 if content['status'] != 'ok':
1443 raise self._unwrap_exception(content)
1457 raise self._unwrap_exception(content)
1444 buffers = msg['buffers']
1458 buffers = msg['buffers']
1445 else:
1459 else:
1446 content = dict(completed=[],pending=[])
1460 content = dict(completed=[],pending=[])
1447
1461
1448 content['completed'].extend(completed)
1462 content['completed'].extend(completed)
1449
1463
1450 if status_only:
1464 if status_only:
1451 return content
1465 return content
1452
1466
1453 failures = []
1467 failures = []
1454 # load cached results into result:
1468 # load cached results into result:
1455 content.update(local_results)
1469 content.update(local_results)
1456 # update cache with results:
1470 # update cache with results:
1457 for msg_id in sorted(theids):
1471 for msg_id in sorted(theids):
1458 if msg_id in content['completed']:
1472 if msg_id in content['completed']:
1459 rec = content[msg_id]
1473 rec = content[msg_id]
1460 parent = rec['header']
1474 parent = rec['header']
1461 header = rec['result_header']
1475 header = rec['result_header']
1462 rcontent = rec['result_content']
1476 rcontent = rec['result_content']
1463 iodict = rec['io']
1477 iodict = rec['io']
1464 if isinstance(rcontent, str):
1478 if isinstance(rcontent, str):
1465 rcontent = self.session.unpack(rcontent)
1479 rcontent = self.session.unpack(rcontent)
1466
1480
1467 md = self.metadata[msg_id]
1481 md = self.metadata[msg_id]
1468 md.update(self._extract_metadata(header, parent, rcontent))
1482 md.update(self._extract_metadata(header, parent, rcontent))
1469 md.update(iodict)
1483 md.update(iodict)
1470
1484
1471 if rcontent['status'] == 'ok':
1485 if rcontent['status'] == 'ok':
1472 res,buffers = util.unserialize_object(buffers)
1486 res,buffers = util.unserialize_object(buffers)
1473 else:
1487 else:
1474 print rcontent
1488 print rcontent
1475 res = self._unwrap_exception(rcontent)
1489 res = self._unwrap_exception(rcontent)
1476 failures.append(res)
1490 failures.append(res)
1477
1491
1478 self.results[msg_id] = res
1492 self.results[msg_id] = res
1479 content[msg_id] = res
1493 content[msg_id] = res
1480
1494
1481 if len(theids) == 1 and failures:
1495 if len(theids) == 1 and failures:
1482 raise failures[0]
1496 raise failures[0]
1483
1497
1484 error.collect_exceptions(failures, "result_status")
1498 error.collect_exceptions(failures, "result_status")
1485 return content
1499 return content
1486
1500
1487 @spinfirst
1501 @spinfirst
1488 def queue_status(self, targets='all', verbose=False):
1502 def queue_status(self, targets='all', verbose=False):
1489 """Fetch the status of engine queues.
1503 """Fetch the status of engine queues.
1490
1504
1491 Parameters
1505 Parameters
1492 ----------
1506 ----------
1493
1507
1494 targets : int/str/list of ints/strs
1508 targets : int/str/list of ints/strs
1495 the engines whose states are to be queried.
1509 the engines whose states are to be queried.
1496 default : all
1510 default : all
1497 verbose : bool
1511 verbose : bool
1498 Whether to return lengths only, or lists of ids for each element
1512 Whether to return lengths only, or lists of ids for each element
1499 """
1513 """
1500 targets = self._build_targets(targets)[1]
1514 targets = self._build_targets(targets)[1]
1501 content = dict(targets=targets, verbose=verbose)
1515 content = dict(targets=targets, verbose=verbose)
1502 self.session.send(self._query_socket, "queue_request", content=content)
1516 self.session.send(self._query_socket, "queue_request", content=content)
1503 idents,msg = self.session.recv(self._query_socket, 0)
1517 idents,msg = self.session.recv(self._query_socket, 0)
1504 if self.debug:
1518 if self.debug:
1505 pprint(msg)
1519 pprint(msg)
1506 content = msg['content']
1520 content = msg['content']
1507 status = content.pop('status')
1521 status = content.pop('status')
1508 if status != 'ok':
1522 if status != 'ok':
1509 raise self._unwrap_exception(content)
1523 raise self._unwrap_exception(content)
1510 return util.rekey(content)
1524 return util.rekey(content)
1511
1525
1512 @spinfirst
1526 @spinfirst
1513 def purge_results(self, jobs=[], targets=[]):
1527 def purge_results(self, jobs=[], targets=[]):
1514 """Tell the controller to forget results.
1528 """Tell the controller to forget results.
1515
1529
1516 Individual results can be purged by msg_id, or the entire
1530 Individual results can be purged by msg_id, or the entire
1517 history of specific targets can be purged.
1531 history of specific targets can be purged.
1518
1532
1519 Parameters
1533 Parameters
1520 ----------
1534 ----------
1521
1535
1522 jobs : str or list of strs or AsyncResult objects
1536 jobs : str or list of strs or AsyncResult objects
1523 the msg_ids whose results should be forgotten.
1537 the msg_ids whose results should be forgotten.
1524 targets : int/str/list of ints/strs
1538 targets : int/str/list of ints/strs
1525 The targets, by uuid or int_id, whose entire history is to be purged.
1539 The targets, by uuid or int_id, whose entire history is to be purged.
1526 Use `targets='all'` to scrub everything from the controller's memory.
1540 Use `targets='all'` to scrub everything from the controller's memory.
1527
1541
1528 default : None
1542 default : None
1529 """
1543 """
1530 if not targets and not jobs:
1544 if not targets and not jobs:
1531 raise ValueError("Must specify at least one of `targets` and `jobs`")
1545 raise ValueError("Must specify at least one of `targets` and `jobs`")
1532 if targets:
1546 if targets:
1533 targets = self._build_targets(targets)[1]
1547 targets = self._build_targets(targets)[1]
1534
1548
1535 # construct msg_ids from jobs
1549 # construct msg_ids from jobs
1536 msg_ids = []
1550 msg_ids = []
1537 if isinstance(jobs, (basestring,AsyncResult)):
1551 if isinstance(jobs, (basestring,AsyncResult)):
1538 jobs = [jobs]
1552 jobs = [jobs]
1539 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1553 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1540 if bad_ids:
1554 if bad_ids:
1541 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1555 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1542 for j in jobs:
1556 for j in jobs:
1543 if isinstance(j, AsyncResult):
1557 if isinstance(j, AsyncResult):
1544 msg_ids.extend(j.msg_ids)
1558 msg_ids.extend(j.msg_ids)
1545 else:
1559 else:
1546 msg_ids.append(j)
1560 msg_ids.append(j)
1547
1561
1548 content = dict(targets=targets, msg_ids=msg_ids)
1562 content = dict(targets=targets, msg_ids=msg_ids)
1549 self.session.send(self._query_socket, "purge_request", content=content)
1563 self.session.send(self._query_socket, "purge_request", content=content)
1550 idents, msg = self.session.recv(self._query_socket, 0)
1564 idents, msg = self.session.recv(self._query_socket, 0)
1551 if self.debug:
1565 if self.debug:
1552 pprint(msg)
1566 pprint(msg)
1553 content = msg['content']
1567 content = msg['content']
1554 if content['status'] != 'ok':
1568 if content['status'] != 'ok':
1555 raise self._unwrap_exception(content)
1569 raise self._unwrap_exception(content)
1556
1570
1557
1571
1558 __all__ = [ 'Client',
1572 __all__ = [ 'Client',
1559 'depend',
1573 'depend',
1560 'require',
1574 'require',
1561 'remote',
1575 'remote',
1562 'parallel',
1576 'parallel',
1563 'RemoteFunction',
1577 'RemoteFunction',
1564 'ParallelFunction',
1578 'ParallelFunction',
1565 'DirectView',
1579 'DirectView',
1566 'LoadBalancedView',
1580 'LoadBalancedView',
1567 'AsyncResult',
1581 'AsyncResult',
1568 'AsyncMapResult',
1582 'AsyncMapResult',
1569 'Reference'
1583 'Reference'
1570 ]
1584 ]
@@ -1,115 +1,118 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller with 0MQ
2 """The IPython Controller with 0MQ
3 This is a collection of one Hub and several Schedulers.
3 This is a collection of one Hub and several Schedulers.
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (C) 2010 The IPython Development Team
6 # Copyright (C) 2010 The IPython Development Team
7 #
7 #
8 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
9 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 from __future__ import print_function
15 from __future__ import print_function
16
16
17 import logging
17 import logging
18 from multiprocessing import Process
18 from multiprocessing import Process
19
19
20 import zmq
20 import zmq
21 from zmq.devices import ProcessMonitoredQueue
21 from zmq.devices import ProcessMonitoredQueue
22 # internal:
22 # internal:
23 from IPython.utils.importstring import import_item
23 from IPython.utils.importstring import import_item
24 from IPython.utils.traitlets import Int, CStr, Instance, List, Bool
24 from IPython.utils.traitlets import Int, CStr, Instance, List, Bool
25
25
26 from .entry_point import signal_children
26 from .entry_point import signal_children
27 from .hub import Hub, HubFactory
27 from .hub import Hub, HubFactory
28 from .scheduler import launch_scheduler
28 from .scheduler import launch_scheduler
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Configurable
31 # Configurable
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34
34
35 class ControllerFactory(HubFactory):
35 class ControllerFactory(HubFactory):
36 """Configurable for setting up a Hub and Schedulers."""
36 """Configurable for setting up a Hub and Schedulers."""
37
37
38 usethreads = Bool(False, config=True)
38 usethreads = Bool(False, config=True)
39 # pure-zmq downstream HWM
39 # pure-zmq downstream HWM
40 hwm = Int(0, config=True)
40 hwm = Int(0, config=True)
41
41
42 # internal
42 # internal
43 children = List()
43 children = List()
44 mq_class = CStr('zmq.devices.ProcessMonitoredQueue')
44 mq_class = CStr('zmq.devices.ProcessMonitoredQueue')
45
45
46 def _usethreads_changed(self, name, old, new):
46 def _usethreads_changed(self, name, old, new):
47 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
47 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
48
48
49 def __init__(self, **kwargs):
49 def __init__(self, **kwargs):
50 super(ControllerFactory, self).__init__(**kwargs)
50 super(ControllerFactory, self).__init__(**kwargs)
51 self.subconstructors.append(self.construct_schedulers)
51 self.subconstructors.append(self.construct_schedulers)
52
52
53 def start(self):
53 def start(self):
54 super(ControllerFactory, self).start()
54 super(ControllerFactory, self).start()
55 child_procs = []
55 child_procs = []
56 for child in self.children:
56 for child in self.children:
57 child.start()
57 child.start()
58 if isinstance(child, ProcessMonitoredQueue):
58 if isinstance(child, ProcessMonitoredQueue):
59 child_procs.append(child.launcher)
59 child_procs.append(child.launcher)
60 elif isinstance(child, Process):
60 elif isinstance(child, Process):
61 child_procs.append(child)
61 child_procs.append(child)
62 if child_procs:
62 if child_procs:
63 signal_children(child_procs)
63 signal_children(child_procs)
64
64
65
65
66 def construct_schedulers(self):
66 def construct_schedulers(self):
67 children = self.children
67 children = self.children
68 mq = import_item(self.mq_class)
68 mq = import_item(self.mq_class)
69
69
70 maybe_inproc = 'inproc://monitor' if self.usethreads else self.monitor_url
70 maybe_inproc = 'inproc://monitor' if self.usethreads else self.monitor_url
71 # IOPub relay (in a Process)
71 # IOPub relay (in a Process)
72 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
72 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
73 q.bind_in(self.client_info['iopub'])
73 q.bind_in(self.client_info['iopub'])
74 q.bind_out(self.engine_info['iopub'])
74 q.bind_out(self.engine_info['iopub'])
75 q.setsockopt_out(zmq.SUBSCRIBE, '')
75 q.setsockopt_out(zmq.SUBSCRIBE, '')
76 q.connect_mon(maybe_inproc)
76 q.connect_mon(maybe_inproc)
77 q.daemon=True
77 q.daemon=True
78 children.append(q)
78 children.append(q)
79
79
80 # Multiplexer Queue (in a Process)
80 # Multiplexer Queue (in a Process)
81 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
81 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
82 q.bind_in(self.client_info['mux'])
82 q.bind_in(self.client_info['mux'])
83 q.setsockopt_in(zmq.IDENTITY, 'mux')
83 q.bind_out(self.engine_info['mux'])
84 q.bind_out(self.engine_info['mux'])
84 q.connect_mon(maybe_inproc)
85 q.connect_mon(maybe_inproc)
85 q.daemon=True
86 q.daemon=True
86 children.append(q)
87 children.append(q)
87
88
88 # Control Queue (in a Process)
89 # Control Queue (in a Process)
89 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
90 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
90 q.bind_in(self.client_info['control'])
91 q.bind_in(self.client_info['control'])
92 q.setsockopt_in(zmq.IDENTITY, 'control')
91 q.bind_out(self.engine_info['control'])
93 q.bind_out(self.engine_info['control'])
92 q.connect_mon(maybe_inproc)
94 q.connect_mon(maybe_inproc)
93 q.daemon=True
95 q.daemon=True
94 children.append(q)
96 children.append(q)
95 # Task Queue (in a Process)
97 # Task Queue (in a Process)
96 if self.scheme == 'pure':
98 if self.scheme == 'pure':
97 self.log.warn("task::using pure XREQ Task scheduler")
99 self.log.warn("task::using pure XREQ Task scheduler")
98 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
100 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
99 q.setsockopt_out(zmq.HWM, self.hwm)
101 q.setsockopt_out(zmq.HWM, self.hwm)
100 q.bind_in(self.client_info['task'][1])
102 q.bind_in(self.client_info['task'][1])
103 q.setsockopt_in(zmq.IDENTITY, 'task')
101 q.bind_out(self.engine_info['task'])
104 q.bind_out(self.engine_info['task'])
102 q.connect_mon(maybe_inproc)
105 q.connect_mon(maybe_inproc)
103 q.daemon=True
106 q.daemon=True
104 children.append(q)
107 children.append(q)
105 elif self.scheme == 'none':
108 elif self.scheme == 'none':
106 self.log.warn("task::using no Task scheduler")
109 self.log.warn("task::using no Task scheduler")
107
110
108 else:
111 else:
109 self.log.info("task::using Python %s Task scheduler"%self.scheme)
112 self.log.info("task::using Python %s Task scheduler"%self.scheme)
110 sargs = (self.client_info['task'][1], self.engine_info['task'], self.monitor_url, self.client_info['notification'])
113 sargs = (self.client_info['task'][1], self.engine_info['task'], self.monitor_url, self.client_info['notification'])
111 kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level, config=self.config)
114 kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level, config=self.config)
112 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
115 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
113 q.daemon=True
116 q.daemon=True
114 children.append(q)
117 children.append(q)
115
118
@@ -1,580 +1,584 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 #----------------------------------------------------------------------
8 #----------------------------------------------------------------------
9 # Imports
9 # Imports
10 #----------------------------------------------------------------------
10 #----------------------------------------------------------------------
11
11
12 from __future__ import print_function
12 from __future__ import print_function
13
13
14 import logging
14 import logging
15 import sys
15 import sys
16
16
17 from datetime import datetime, timedelta
17 from datetime import datetime, timedelta
18 from random import randint, random
18 from random import randint, random
19 from types import FunctionType
19 from types import FunctionType
20
20
21 try:
21 try:
22 import numpy
22 import numpy
23 except ImportError:
23 except ImportError:
24 numpy = None
24 numpy = None
25
25
26 import zmq
26 import zmq
27 from zmq.eventloop import ioloop, zmqstream
27 from zmq.eventloop import ioloop, zmqstream
28
28
29 # local imports
29 # local imports
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.utils.traitlets import Instance, Dict, List, Set
31 from IPython.utils.traitlets import Instance, Dict, List, Set
32
32
33 from . import error
33 from . import error
34 from .dependency import Dependency
34 from .dependency import Dependency
35 from .entry_point import connect_logger, local_logger
35 from .entry_point import connect_logger, local_logger
36 from .factory import SessionFactory
36 from .factory import SessionFactory
37
37
38
38
39 @decorator
39 @decorator
40 def logged(f,self,*args,**kwargs):
40 def logged(f,self,*args,**kwargs):
41 # print ("#--------------------")
41 # print ("#--------------------")
42 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
42 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
43 # print ("#--")
43 # print ("#--")
44 return f(self,*args, **kwargs)
44 return f(self,*args, **kwargs)
45
45
46 #----------------------------------------------------------------------
46 #----------------------------------------------------------------------
47 # Chooser functions
47 # Chooser functions
48 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
49
49
50 def plainrandom(loads):
50 def plainrandom(loads):
51 """Plain random pick."""
51 """Plain random pick."""
52 n = len(loads)
52 n = len(loads)
53 return randint(0,n-1)
53 return randint(0,n-1)
54
54
55 def lru(loads):
55 def lru(loads):
56 """Always pick the front of the line.
56 """Always pick the front of the line.
57
57
58 The content of `loads` is ignored.
58 The content of `loads` is ignored.
59
59
60 Assumes LRU ordering of loads, with oldest first.
60 Assumes LRU ordering of loads, with oldest first.
61 """
61 """
62 return 0
62 return 0
63
63
64 def twobin(loads):
64 def twobin(loads):
65 """Pick two at random, use the LRU of the two.
65 """Pick two at random, use the LRU of the two.
66
66
67 The content of loads is ignored.
67 The content of loads is ignored.
68
68
69 Assumes LRU ordering of loads, with oldest first.
69 Assumes LRU ordering of loads, with oldest first.
70 """
70 """
71 n = len(loads)
71 n = len(loads)
72 a = randint(0,n-1)
72 a = randint(0,n-1)
73 b = randint(0,n-1)
73 b = randint(0,n-1)
74 return min(a,b)
74 return min(a,b)
75
75
76 def weighted(loads):
76 def weighted(loads):
77 """Pick two at random using inverse load as weight.
77 """Pick two at random using inverse load as weight.
78
78
79 Return the less loaded of the two.
79 Return the less loaded of the two.
80 """
80 """
81 # weight 0 a million times more than 1:
81 # weight 0 a million times more than 1:
82 weights = 1./(1e-6+numpy.array(loads))
82 weights = 1./(1e-6+numpy.array(loads))
83 sums = weights.cumsum()
83 sums = weights.cumsum()
84 t = sums[-1]
84 t = sums[-1]
85 x = random()*t
85 x = random()*t
86 y = random()*t
86 y = random()*t
87 idx = 0
87 idx = 0
88 idy = 0
88 idy = 0
89 while sums[idx] < x:
89 while sums[idx] < x:
90 idx += 1
90 idx += 1
91 while sums[idy] < y:
91 while sums[idy] < y:
92 idy += 1
92 idy += 1
93 if weights[idy] > weights[idx]:
93 if weights[idy] > weights[idx]:
94 return idy
94 return idy
95 else:
95 else:
96 return idx
96 return idx
97
97
98 def leastload(loads):
98 def leastload(loads):
99 """Always choose the lowest load.
99 """Always choose the lowest load.
100
100
101 If the lowest load occurs more than once, the first
101 If the lowest load occurs more than once, the first
102 occurance will be used. If loads has LRU ordering, this means
102 occurance will be used. If loads has LRU ordering, this means
103 the LRU of those with the lowest load is chosen.
103 the LRU of those with the lowest load is chosen.
104 """
104 """
105 return loads.index(min(loads))
105 return loads.index(min(loads))
106
106
107 #---------------------------------------------------------------------
107 #---------------------------------------------------------------------
108 # Classes
108 # Classes
109 #---------------------------------------------------------------------
109 #---------------------------------------------------------------------
110 # store empty default dependency:
110 # store empty default dependency:
111 MET = Dependency([])
111 MET = Dependency([])
112
112
113 class TaskScheduler(SessionFactory):
113 class TaskScheduler(SessionFactory):
114 """Python TaskScheduler object.
114 """Python TaskScheduler object.
115
115
116 This is the simplest object that supports msg_id based
116 This is the simplest object that supports msg_id based
117 DAG dependencies. *Only* task msg_ids are checked, not
117 DAG dependencies. *Only* task msg_ids are checked, not
118 msg_ids of jobs submitted via the MUX queue.
118 msg_ids of jobs submitted via the MUX queue.
119
119
120 """
120 """
121
121
122 # input arguments:
122 # input arguments:
123 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
123 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
124 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
124 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
125 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
125 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
126 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
126 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
127 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
127 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
128
128
129 # internals:
129 # internals:
130 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
130 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
131 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
131 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
132 pending = Dict() # dict by engine_uuid of submitted tasks
132 pending = Dict() # dict by engine_uuid of submitted tasks
133 completed = Dict() # dict by engine_uuid of completed tasks
133 completed = Dict() # dict by engine_uuid of completed tasks
134 failed = Dict() # dict by engine_uuid of failed tasks
134 failed = Dict() # dict by engine_uuid of failed tasks
135 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
135 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
136 clients = Dict() # dict by msg_id for who submitted the task
136 clients = Dict() # dict by msg_id for who submitted the task
137 targets = List() # list of target IDENTs
137 targets = List() # list of target IDENTs
138 loads = List() # list of engine loads
138 loads = List() # list of engine loads
139 all_completed = Set() # set of all completed tasks
139 all_completed = Set() # set of all completed tasks
140 all_failed = Set() # set of all failed tasks
140 all_failed = Set() # set of all failed tasks
141 all_done = Set() # set of all finished tasks=union(completed,failed)
141 all_done = Set() # set of all finished tasks=union(completed,failed)
142 all_ids = Set() # set of all submitted task IDs
142 all_ids = Set() # set of all submitted task IDs
143 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
143 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
144 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
144 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
145
145
146
146
147 def start(self):
147 def start(self):
148 self.engine_stream.on_recv(self.dispatch_result, copy=False)
148 self.engine_stream.on_recv(self.dispatch_result, copy=False)
149 self._notification_handlers = dict(
149 self._notification_handlers = dict(
150 registration_notification = self._register_engine,
150 registration_notification = self._register_engine,
151 unregistration_notification = self._unregister_engine
151 unregistration_notification = self._unregister_engine
152 )
152 )
153 self.notifier_stream.on_recv(self.dispatch_notification)
153 self.notifier_stream.on_recv(self.dispatch_notification)
154 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
154 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
155 self.auditor.start()
155 self.auditor.start()
156 self.log.info("Scheduler started...%r"%self)
156 self.log.info("Scheduler started...%r"%self)
157
157
158 def resume_receiving(self):
158 def resume_receiving(self):
159 """Resume accepting jobs."""
159 """Resume accepting jobs."""
160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
161
161
162 def stop_receiving(self):
162 def stop_receiving(self):
163 """Stop accepting jobs while there are no engines.
163 """Stop accepting jobs while there are no engines.
164 Leave them in the ZMQ queue."""
164 Leave them in the ZMQ queue."""
165 self.client_stream.on_recv(None)
165 self.client_stream.on_recv(None)
166
166
167 #-----------------------------------------------------------------------
167 #-----------------------------------------------------------------------
168 # [Un]Registration Handling
168 # [Un]Registration Handling
169 #-----------------------------------------------------------------------
169 #-----------------------------------------------------------------------
170
170
171 def dispatch_notification(self, msg):
171 def dispatch_notification(self, msg):
172 """dispatch register/unregister events."""
172 """dispatch register/unregister events."""
173 idents,msg = self.session.feed_identities(msg)
173 idents,msg = self.session.feed_identities(msg)
174 msg = self.session.unpack_message(msg)
174 msg = self.session.unpack_message(msg)
175 msg_type = msg['msg_type']
175 msg_type = msg['msg_type']
176 handler = self._notification_handlers.get(msg_type, None)
176 handler = self._notification_handlers.get(msg_type, None)
177 if handler is None:
177 if handler is None:
178 raise Exception("Unhandled message type: %s"%msg_type)
178 raise Exception("Unhandled message type: %s"%msg_type)
179 else:
179 else:
180 try:
180 try:
181 handler(str(msg['content']['queue']))
181 handler(str(msg['content']['queue']))
182 except KeyError:
182 except KeyError:
183 self.log.error("task::Invalid notification msg: %s"%msg)
183 self.log.error("task::Invalid notification msg: %s"%msg)
184
184
185 @logged
185 @logged
186 def _register_engine(self, uid):
186 def _register_engine(self, uid):
187 """New engine with ident `uid` became available."""
187 """New engine with ident `uid` became available."""
188 # head of the line:
188 # head of the line:
189 self.targets.insert(0,uid)
189 self.targets.insert(0,uid)
190 self.loads.insert(0,0)
190 self.loads.insert(0,0)
191 # initialize sets
191 # initialize sets
192 self.completed[uid] = set()
192 self.completed[uid] = set()
193 self.failed[uid] = set()
193 self.failed[uid] = set()
194 self.pending[uid] = {}
194 self.pending[uid] = {}
195 if len(self.targets) == 1:
195 if len(self.targets) == 1:
196 self.resume_receiving()
196 self.resume_receiving()
197
197
198 def _unregister_engine(self, uid):
198 def _unregister_engine(self, uid):
199 """Existing engine with ident `uid` became unavailable."""
199 """Existing engine with ident `uid` became unavailable."""
200 if len(self.targets) == 1:
200 if len(self.targets) == 1:
201 # this was our only engine
201 # this was our only engine
202 self.stop_receiving()
202 self.stop_receiving()
203
203
204 # handle any potentially finished tasks:
204 # handle any potentially finished tasks:
205 self.engine_stream.flush()
205 self.engine_stream.flush()
206
206
207 self.completed.pop(uid)
207 self.completed.pop(uid)
208 self.failed.pop(uid)
208 self.failed.pop(uid)
209 # don't pop destinations, because it might be used later
209 # don't pop destinations, because it might be used later
210 # map(self.destinations.pop, self.completed.pop(uid))
210 # map(self.destinations.pop, self.completed.pop(uid))
211 # map(self.destinations.pop, self.failed.pop(uid))
211 # map(self.destinations.pop, self.failed.pop(uid))
212
212
213 idx = self.targets.index(uid)
213 idx = self.targets.index(uid)
214 self.targets.pop(idx)
214 self.targets.pop(idx)
215 self.loads.pop(idx)
215 self.loads.pop(idx)
216
216
217 # wait 5 seconds before cleaning up pending jobs, since the results might
217 # wait 5 seconds before cleaning up pending jobs, since the results might
218 # still be incoming
218 # still be incoming
219 if self.pending[uid]:
219 if self.pending[uid]:
220 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
220 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
221 dc.start()
221 dc.start()
222
222
223 @logged
223 @logged
224 def handle_stranded_tasks(self, engine):
224 def handle_stranded_tasks(self, engine):
225 """Deal with jobs resident in an engine that died."""
225 """Deal with jobs resident in an engine that died."""
226 lost = self.pending.pop(engine)
226 lost = self.pending.pop(engine)
227
227
228 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
228 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
229 self.all_failed.add(msg_id)
229 self.all_failed.add(msg_id)
230 self.all_done.add(msg_id)
230 self.all_done.add(msg_id)
231 idents,msg = self.session.feed_identities(raw_msg, copy=False)
231 idents,msg = self.session.feed_identities(raw_msg, copy=False)
232 msg = self.session.unpack_message(msg, copy=False, content=False)
232 msg = self.session.unpack_message(msg, copy=False, content=False)
233 parent = msg['header']
233 parent = msg['header']
234 idents = [idents[0],engine]+idents[1:]
234 idents = [idents[0],engine]+idents[1:]
235 print (idents)
235 print (idents)
236 try:
236 try:
237 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
237 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
238 except:
238 except:
239 content = error.wrap_exception()
239 content = error.wrap_exception()
240 msg = self.session.send(self.client_stream, 'apply_reply', content,
240 msg = self.session.send(self.client_stream, 'apply_reply', content,
241 parent=parent, ident=idents)
241 parent=parent, ident=idents)
242 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
242 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
243 self.update_graph(msg_id)
243 self.update_graph(msg_id)
244
244
245
245
246 #-----------------------------------------------------------------------
246 #-----------------------------------------------------------------------
247 # Job Submission
247 # Job Submission
248 #-----------------------------------------------------------------------
248 #-----------------------------------------------------------------------
249 @logged
249 @logged
250 def dispatch_submission(self, raw_msg):
250 def dispatch_submission(self, raw_msg):
251 """Dispatch job submission to appropriate handlers."""
251 """Dispatch job submission to appropriate handlers."""
252 # ensure targets up to date:
252 # ensure targets up to date:
253 self.notifier_stream.flush()
253 self.notifier_stream.flush()
254 try:
254 try:
255 idents, msg = self.session.feed_identities(raw_msg, copy=False)
255 idents, msg = self.session.feed_identities(raw_msg, copy=False)
256 msg = self.session.unpack_message(msg, content=False, copy=False)
256 msg = self.session.unpack_message(msg, content=False, copy=False)
257 except:
257 except:
258 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
258 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
259 return
259 return
260
260
261 # send to monitor
261 # send to monitor
262 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
262 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
263
263
264 header = msg['header']
264 header = msg['header']
265 msg_id = header['msg_id']
265 msg_id = header['msg_id']
266 self.all_ids.add(msg_id)
266 self.all_ids.add(msg_id)
267
267
268 # targets
268 # targets
269 targets = set(header.get('targets', []))
269 targets = set(header.get('targets', []))
270
270
271 # time dependencies
271 # time dependencies
272 after = Dependency(header.get('after', []))
272 after = Dependency(header.get('after', []))
273 if after.all:
273 if after.all:
274 after.difference_update(self.all_completed)
274 after.difference_update(self.all_completed)
275 if not after.success_only:
275 if not after.success_only:
276 after.difference_update(self.all_failed)
276 after.difference_update(self.all_failed)
277 if after.check(self.all_completed, self.all_failed):
277 if after.check(self.all_completed, self.all_failed):
278 # recast as empty set, if `after` already met,
278 # recast as empty set, if `after` already met,
279 # to prevent unnecessary set comparisons
279 # to prevent unnecessary set comparisons
280 after = MET
280 after = MET
281
281
282 # location dependencies
282 # location dependencies
283 follow = Dependency(header.get('follow', []))
283 follow = Dependency(header.get('follow', []))
284
284
285 # turn timeouts into datetime objects:
285 # turn timeouts into datetime objects:
286 timeout = header.get('timeout', None)
286 timeout = header.get('timeout', None)
287 if timeout:
287 if timeout:
288 timeout = datetime.now() + timedelta(0,timeout,0)
288 timeout = datetime.now() + timedelta(0,timeout,0)
289
289
290 args = [raw_msg, targets, after, follow, timeout]
290 args = [raw_msg, targets, after, follow, timeout]
291
291
292 # validate and reduce dependencies:
292 # validate and reduce dependencies:
293 for dep in after,follow:
293 for dep in after,follow:
294 # check valid:
294 # check valid:
295 if msg_id in dep or dep.difference(self.all_ids):
295 if msg_id in dep or dep.difference(self.all_ids):
296 self.depending[msg_id] = args
296 self.depending[msg_id] = args
297 return self.fail_unreachable(msg_id, error.InvalidDependency)
297 return self.fail_unreachable(msg_id, error.InvalidDependency)
298 # check if unreachable:
298 # check if unreachable:
299 if dep.unreachable(self.all_failed):
299 if dep.unreachable(self.all_failed):
300 self.depending[msg_id] = args
300 self.depending[msg_id] = args
301 return self.fail_unreachable(msg_id)
301 return self.fail_unreachable(msg_id)
302
302
303 if after.check(self.all_completed, self.all_failed):
303 if after.check(self.all_completed, self.all_failed):
304 # time deps already met, try to run
304 # time deps already met, try to run
305 if not self.maybe_run(msg_id, *args):
305 if not self.maybe_run(msg_id, *args):
306 # can't run yet
306 # can't run yet
307 self.save_unmet(msg_id, *args)
307 self.save_unmet(msg_id, *args)
308 else:
308 else:
309 self.save_unmet(msg_id, *args)
309 self.save_unmet(msg_id, *args)
310
310
311 # @logged
311 # @logged
312 def audit_timeouts(self):
312 def audit_timeouts(self):
313 """Audit all waiting tasks for expired timeouts."""
313 """Audit all waiting tasks for expired timeouts."""
314 now = datetime.now()
314 now = datetime.now()
315 for msg_id in self.depending.keys():
315 for msg_id in self.depending.keys():
316 # must recheck, in case one failure cascaded to another:
316 # must recheck, in case one failure cascaded to another:
317 if msg_id in self.depending:
317 if msg_id in self.depending:
318 raw,after,targets,follow,timeout = self.depending[msg_id]
318 raw,after,targets,follow,timeout = self.depending[msg_id]
319 if timeout and timeout < now:
319 if timeout and timeout < now:
320 self.fail_unreachable(msg_id, timeout=True)
320 self.fail_unreachable(msg_id, timeout=True)
321
321
322 @logged
322 @logged
323 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
323 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
324 """a task has become unreachable, send a reply with an ImpossibleDependency
324 """a task has become unreachable, send a reply with an ImpossibleDependency
325 error."""
325 error."""
326 if msg_id not in self.depending:
326 if msg_id not in self.depending:
327 self.log.error("msg %r already failed!"%msg_id)
327 self.log.error("msg %r already failed!"%msg_id)
328 return
328 return
329 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
329 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
330 for mid in follow.union(after):
330 for mid in follow.union(after):
331 if mid in self.graph:
331 if mid in self.graph:
332 self.graph[mid].remove(msg_id)
332 self.graph[mid].remove(msg_id)
333
333
334 # FIXME: unpacking a message I've already unpacked, but didn't save:
334 # FIXME: unpacking a message I've already unpacked, but didn't save:
335 idents,msg = self.session.feed_identities(raw_msg, copy=False)
335 idents,msg = self.session.feed_identities(raw_msg, copy=False)
336 msg = self.session.unpack_message(msg, copy=False, content=False)
336 msg = self.session.unpack_message(msg, copy=False, content=False)
337 header = msg['header']
337 header = msg['header']
338
338
339 try:
339 try:
340 raise why()
340 raise why()
341 except:
341 except:
342 content = error.wrap_exception()
342 content = error.wrap_exception()
343
343
344 self.all_done.add(msg_id)
344 self.all_done.add(msg_id)
345 self.all_failed.add(msg_id)
345 self.all_failed.add(msg_id)
346
346
347 msg = self.session.send(self.client_stream, 'apply_reply', content,
347 msg = self.session.send(self.client_stream, 'apply_reply', content,
348 parent=header, ident=idents)
348 parent=header, ident=idents)
349 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
349 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
350
350
351 self.update_graph(msg_id, success=False)
351 self.update_graph(msg_id, success=False)
352
352
353 @logged
353 @logged
354 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
354 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
355 """check location dependencies, and run if they are met."""
355 """check location dependencies, and run if they are met."""
356 blacklist = self.blacklist.setdefault(msg_id, set())
356 blacklist = self.blacklist.setdefault(msg_id, set())
357 if follow or targets or blacklist:
357 if follow or targets or blacklist:
358 # we need a can_run filter
358 # we need a can_run filter
359 def can_run(idx):
359 def can_run(idx):
360 target = self.targets[idx]
360 target = self.targets[idx]
361 # check targets
361 # check targets
362 if targets and target not in targets:
362 if targets and target not in targets:
363 return False
363 return False
364 # check blacklist
364 # check blacklist
365 if target in blacklist:
365 if target in blacklist:
366 return False
366 return False
367 # check follow
367 # check follow
368 return follow.check(self.completed[target], self.failed[target])
368 return follow.check(self.completed[target], self.failed[target])
369
369
370 indices = filter(can_run, range(len(self.targets)))
370 indices = filter(can_run, range(len(self.targets)))
371 if not indices:
371 if not indices:
372 # couldn't run
372 # couldn't run
373 if follow.all:
373 if follow.all:
374 # check follow for impossibility
374 # check follow for impossibility
375 dests = set()
375 dests = set()
376 relevant = self.all_completed if follow.success_only else self.all_done
376 relevant = self.all_completed if follow.success_only else self.all_done
377 for m in follow.intersection(relevant):
377 for m in follow.intersection(relevant):
378 dests.add(self.destinations[m])
378 dests.add(self.destinations[m])
379 if len(dests) > 1:
379 if len(dests) > 1:
380 self.fail_unreachable(msg_id)
380 self.fail_unreachable(msg_id)
381 return False
381 return False
382 if targets:
382 if targets:
383 # check blacklist+targets for impossibility
383 # check blacklist+targets for impossibility
384 targets.difference_update(blacklist)
384 targets.difference_update(blacklist)
385 if not targets or not targets.intersection(self.targets):
385 if not targets or not targets.intersection(self.targets):
386 self.fail_unreachable(msg_id)
386 self.fail_unreachable(msg_id)
387 return False
387 return False
388 return False
388 return False
389 else:
389 else:
390 indices = None
390 indices = None
391
391
392 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
392 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
393 return True
393 return True
394
394
395 @logged
395 @logged
396 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
396 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
397 """Save a message for later submission when its dependencies are met."""
397 """Save a message for later submission when its dependencies are met."""
398 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
398 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
399 # track the ids in follow or after, but not those already finished
399 # track the ids in follow or after, but not those already finished
400 for dep_id in after.union(follow).difference(self.all_done):
400 for dep_id in after.union(follow).difference(self.all_done):
401 if dep_id not in self.graph:
401 if dep_id not in self.graph:
402 self.graph[dep_id] = set()
402 self.graph[dep_id] = set()
403 self.graph[dep_id].add(msg_id)
403 self.graph[dep_id].add(msg_id)
404
404
405 @logged
405 @logged
406 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
406 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
407 """Submit a task to any of a subset of our targets."""
407 """Submit a task to any of a subset of our targets."""
408 if indices:
408 if indices:
409 loads = [self.loads[i] for i in indices]
409 loads = [self.loads[i] for i in indices]
410 else:
410 else:
411 loads = self.loads
411 loads = self.loads
412 idx = self.scheme(loads)
412 idx = self.scheme(loads)
413 if indices:
413 if indices:
414 idx = indices[idx]
414 idx = indices[idx]
415 target = self.targets[idx]
415 target = self.targets[idx]
416 # print (target, map(str, msg[:3]))
416 # print (target, map(str, msg[:3]))
417 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
417 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
418 self.engine_stream.send_multipart(raw_msg, copy=False)
418 self.engine_stream.send_multipart(raw_msg, copy=False)
419 self.add_job(idx)
419 self.add_job(idx)
420 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
420 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
421 content = dict(msg_id=msg_id, engine_id=target)
421 content = dict(msg_id=msg_id, engine_id=target)
422 self.session.send(self.mon_stream, 'task_destination', content=content,
422 self.session.send(self.mon_stream, 'task_destination', content=content,
423 ident=['tracktask',self.session.session])
423 ident=['tracktask',self.session.session])
424
424
425 #-----------------------------------------------------------------------
425 #-----------------------------------------------------------------------
426 # Result Handling
426 # Result Handling
427 #-----------------------------------------------------------------------
427 #-----------------------------------------------------------------------
428 @logged
428 @logged
429 def dispatch_result(self, raw_msg):
429 def dispatch_result(self, raw_msg):
430 """dispatch method for result replies"""
430 """dispatch method for result replies"""
431 try:
431 try:
432 idents,msg = self.session.feed_identities(raw_msg, copy=False)
432 idents,msg = self.session.feed_identities(raw_msg, copy=False)
433 msg = self.session.unpack_message(msg, content=False, copy=False)
433 msg = self.session.unpack_message(msg, content=False, copy=False)
434 except:
434 except:
435 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
435 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
436 return
436 return
437
437
438 header = msg['header']
438 header = msg['header']
439 if header.get('dependencies_met', True):
439 if header.get('dependencies_met', True):
440 success = (header['status'] == 'ok')
440 success = (header['status'] == 'ok')
441 self.handle_result(idents, msg['parent_header'], raw_msg, success)
441 self.handle_result(idents, msg['parent_header'], raw_msg, success)
442 # send to Hub monitor
442 # send to Hub monitor
443 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
443 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
444 else:
444 else:
445 self.handle_unmet_dependency(idents, msg['parent_header'])
445 self.handle_unmet_dependency(idents, msg['parent_header'])
446
446
447 @logged
447 @logged
448 def handle_result(self, idents, parent, raw_msg, success=True):
448 def handle_result(self, idents, parent, raw_msg, success=True):
449 """handle a real task result, either success or failure"""
449 """handle a real task result, either success or failure"""
450 # first, relay result to client
450 # first, relay result to client
451 engine = idents[0]
451 engine = idents[0]
452 client = idents[1]
452 client = idents[1]
453 # swap_ids for XREP-XREP mirror
453 # swap_ids for XREP-XREP mirror
454 raw_msg[:2] = [client,engine]
454 raw_msg[:2] = [client,engine]
455 # print (map(str, raw_msg[:4]))
455 # print (map(str, raw_msg[:4]))
456 self.client_stream.send_multipart(raw_msg, copy=False)
456 self.client_stream.send_multipart(raw_msg, copy=False)
457 # now, update our data structures
457 # now, update our data structures
458 msg_id = parent['msg_id']
458 msg_id = parent['msg_id']
459 self.blacklist.pop(msg_id, None)
459 self.blacklist.pop(msg_id, None)
460 self.pending[engine].pop(msg_id)
460 self.pending[engine].pop(msg_id)
461 if success:
461 if success:
462 self.completed[engine].add(msg_id)
462 self.completed[engine].add(msg_id)
463 self.all_completed.add(msg_id)
463 self.all_completed.add(msg_id)
464 else:
464 else:
465 self.failed[engine].add(msg_id)
465 self.failed[engine].add(msg_id)
466 self.all_failed.add(msg_id)
466 self.all_failed.add(msg_id)
467 self.all_done.add(msg_id)
467 self.all_done.add(msg_id)
468 self.destinations[msg_id] = engine
468 self.destinations[msg_id] = engine
469
469
470 self.update_graph(msg_id, success)
470 self.update_graph(msg_id, success)
471
471
472 @logged
472 @logged
473 def handle_unmet_dependency(self, idents, parent):
473 def handle_unmet_dependency(self, idents, parent):
474 """handle an unmet dependency"""
474 """handle an unmet dependency"""
475 engine = idents[0]
475 engine = idents[0]
476 msg_id = parent['msg_id']
476 msg_id = parent['msg_id']
477
477
478 if msg_id not in self.blacklist:
478 if msg_id not in self.blacklist:
479 self.blacklist[msg_id] = set()
479 self.blacklist[msg_id] = set()
480 self.blacklist[msg_id].add(engine)
480 self.blacklist[msg_id].add(engine)
481
481
482 args = self.pending[engine].pop(msg_id)
482 args = self.pending[engine].pop(msg_id)
483 raw,targets,after,follow,timeout = args
483 raw,targets,after,follow,timeout = args
484
484
485 if self.blacklist[msg_id] == targets:
485 if self.blacklist[msg_id] == targets:
486 self.depending[msg_id] = args
486 self.depending[msg_id] = args
487 return self.fail_unreachable(msg_id)
487 return self.fail_unreachable(msg_id)
488
488
489 elif not self.maybe_run(msg_id, *args):
489 elif not self.maybe_run(msg_id, *args):
490 # resubmit failed, put it back in our dependency tree
490 # resubmit failed, put it back in our dependency tree
491 self.save_unmet(msg_id, *args)
491 self.save_unmet(msg_id, *args)
492
492
493
493
494 @logged
494 @logged
495 def update_graph(self, dep_id, success=True):
495 def update_graph(self, dep_id, success=True):
496 """dep_id just finished. Update our dependency
496 """dep_id just finished. Update our dependency
497 graph and submit any jobs that just became runable."""
497 graph and submit any jobs that just became runable."""
498 # print ("\n\n***********")
498 # print ("\n\n***********")
499 # pprint (dep_id)
499 # pprint (dep_id)
500 # pprint (self.graph)
500 # pprint (self.graph)
501 # pprint (self.depending)
501 # pprint (self.depending)
502 # pprint (self.all_completed)
502 # pprint (self.all_completed)
503 # pprint (self.all_failed)
503 # pprint (self.all_failed)
504 # print ("\n\n***********\n\n")
504 # print ("\n\n***********\n\n")
505 if dep_id not in self.graph:
505 if dep_id not in self.graph:
506 return
506 return
507 jobs = self.graph.pop(dep_id)
507 jobs = self.graph.pop(dep_id)
508
508
509 for msg_id in jobs:
509 for msg_id in jobs:
510 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
510 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
511 # if dep_id in after:
511 # if dep_id in after:
512 # if after.all and (success or not after.success_only):
512 # if after.all and (success or not after.success_only):
513 # after.remove(dep_id)
513 # after.remove(dep_id)
514
514
515 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
515 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
516 self.fail_unreachable(msg_id)
516 self.fail_unreachable(msg_id)
517
517
518 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
518 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
519 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
519 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
520
520
521 self.depending.pop(msg_id)
521 self.depending.pop(msg_id)
522 for mid in follow.union(after):
522 for mid in follow.union(after):
523 if mid in self.graph:
523 if mid in self.graph:
524 self.graph[mid].remove(msg_id)
524 self.graph[mid].remove(msg_id)
525
525
526 #----------------------------------------------------------------------
526 #----------------------------------------------------------------------
527 # methods to be overridden by subclasses
527 # methods to be overridden by subclasses
528 #----------------------------------------------------------------------
528 #----------------------------------------------------------------------
529
529
530 def add_job(self, idx):
530 def add_job(self, idx):
531 """Called after self.targets[idx] just got the job with header.
531 """Called after self.targets[idx] just got the job with header.
532 Override with subclasses. The default ordering is simple LRU.
532 Override with subclasses. The default ordering is simple LRU.
533 The default loads are the number of outstanding jobs."""
533 The default loads are the number of outstanding jobs."""
534 self.loads[idx] += 1
534 self.loads[idx] += 1
535 for lis in (self.targets, self.loads):
535 for lis in (self.targets, self.loads):
536 lis.append(lis.pop(idx))
536 lis.append(lis.pop(idx))
537
537
538
538
539 def finish_job(self, idx):
539 def finish_job(self, idx):
540 """Called after self.targets[idx] just finished a job.
540 """Called after self.targets[idx] just finished a job.
541 Override with subclasses."""
541 Override with subclasses."""
542 self.loads[idx] -= 1
542 self.loads[idx] -= 1
543
543
544
544
545
545
546 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
546 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
547 log_addr=None, loglevel=logging.DEBUG, scheme='lru'):
547 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
548 identity=b'task'):
548 from zmq.eventloop import ioloop
549 from zmq.eventloop import ioloop
549 from zmq.eventloop.zmqstream import ZMQStream
550 from zmq.eventloop.zmqstream import ZMQStream
550
551
551 ctx = zmq.Context()
552 ctx = zmq.Context()
552 loop = ioloop.IOLoop()
553 loop = ioloop.IOLoop()
553 print (in_addr, out_addr, mon_addr, not_addr)
554 print (in_addr, out_addr, mon_addr, not_addr)
554 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
555 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
556 ins.setsockopt(zmq.IDENTITY, identity)
555 ins.bind(in_addr)
557 ins.bind(in_addr)
558
556 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
559 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
560 outs.setsockopt(zmq.IDENTITY, identity)
557 outs.bind(out_addr)
561 outs.bind(out_addr)
558 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
562 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
559 mons.connect(mon_addr)
563 mons.connect(mon_addr)
560 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
564 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
561 nots.setsockopt(zmq.SUBSCRIBE, '')
565 nots.setsockopt(zmq.SUBSCRIBE, '')
562 nots.connect(not_addr)
566 nots.connect(not_addr)
563
567
564 scheme = globals().get(scheme, None)
568 scheme = globals().get(scheme, None)
565 # setup logging
569 # setup logging
566 if log_addr:
570 if log_addr:
567 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
571 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
568 else:
572 else:
569 local_logger(logname, loglevel)
573 local_logger(logname, loglevel)
570
574
571 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
575 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
572 mon_stream=mons, notifier_stream=nots,
576 mon_stream=mons, notifier_stream=nots,
573 scheme=scheme, loop=loop, logname=logname,
577 scheme=scheme, loop=loop, logname=logname,
574 config=config)
578 config=config)
575 scheduler.start()
579 scheduler.start()
576 try:
580 try:
577 loop.start()
581 loop.start()
578 except KeyboardInterrupt:
582 except KeyboardInterrupt:
579 print ("interrupted, exiting...", file=sys.__stderr__)
583 print ("interrupted, exiting...", file=sys.__stderr__)
580
584
@@ -1,48 +1,48 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 import tempfile
3 import tempfile
4 import time
4 import time
5 from subprocess import Popen, PIPE, STDOUT
5 from subprocess import Popen, PIPE, STDOUT
6
6
7 from IPython.zmq.parallel import client
7 from IPython.zmq.parallel import client
8
8
9 processes = []
9 processes = []
10 blackhole = tempfile.TemporaryFile()
10 blackhole = tempfile.TemporaryFile()
11
11
12 # nose setup/teardown
12 # nose setup/teardown
13
13
14 def setup():
14 def setup():
15 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 cp = Popen('ipcontrollerz --profile iptest -r --log-level 10 --log-to-file'.split(), stdout=blackhole, stderr=STDOUT)
16 processes.append(cp)
16 processes.append(cp)
17 time.sleep(.5)
17 time.sleep(.5)
18 add_engine()
18 add_engine()
19 c = client.Client(profile='iptest')
19 c = client.Client(profile='iptest')
20 while not c.ids:
20 while not c.ids:
21 time.sleep(.1)
21 time.sleep(.1)
22 c.spin()
22 c.spin()
23
23
24 def add_engine(profile='iptest'):
24 def add_engine(profile='iptest'):
25 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
25 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '10', '--log-to-file'], stdout=blackhole, stderr=STDOUT)
26 # ep.start()
26 # ep.start()
27 processes.append(ep)
27 processes.append(ep)
28 return ep
28 return ep
29
29
30 def teardown():
30 def teardown():
31 time.sleep(1)
31 time.sleep(1)
32 while processes:
32 while processes:
33 p = processes.pop()
33 p = processes.pop()
34 if p.poll() is None:
34 if p.poll() is None:
35 try:
35 try:
36 p.terminate()
36 p.terminate()
37 except Exception, e:
37 except Exception, e:
38 print e
38 print e
39 pass
39 pass
40 if p.poll() is None:
40 if p.poll() is None:
41 time.sleep(.25)
41 time.sleep(.25)
42 if p.poll() is None:
42 if p.poll() is None:
43 try:
43 try:
44 print 'killing'
44 print 'killing'
45 p.kill()
45 p.kill()
46 except:
46 except:
47 print "couldn't shutdown process: ", p
47 print "couldn't shutdown process: ", p
48
48
@@ -1,100 +1,105 b''
1 import sys
2 import tempfile
1 import time
3 import time
2 from signal import SIGINT
4 from signal import SIGINT
3 from multiprocessing import Process
5 from multiprocessing import Process
4
6
5 from nose import SkipTest
7 from nose import SkipTest
6
8
7 from zmq.tests import BaseZMQTestCase
9 from zmq.tests import BaseZMQTestCase
8
10
9 from IPython.external.decorator import decorator
11 from IPython.external.decorator import decorator
10
12
11 from IPython.zmq.parallel import error
13 from IPython.zmq.parallel import error
12 from IPython.zmq.parallel.client import Client
14 from IPython.zmq.parallel.client import Client
13 from IPython.zmq.parallel.ipcluster import launch_process
15 from IPython.zmq.parallel.ipcluster import launch_process
14 from IPython.zmq.parallel.entry_point import select_random_ports
16 from IPython.zmq.parallel.entry_point import select_random_ports
15 from IPython.zmq.parallel.tests import processes,add_engine
17 from IPython.zmq.parallel.tests import processes,add_engine
16
18
17 # simple tasks for use in apply tests
19 # simple tasks for use in apply tests
18
20
19 def segfault():
21 def segfault():
20 """this will segfault"""
22 """this will segfault"""
21 import ctypes
23 import ctypes
22 ctypes.memset(-1,0,1)
24 ctypes.memset(-1,0,1)
23
25
24 def wait(n):
26 def wait(n):
25 """sleep for a time"""
27 """sleep for a time"""
26 import time
28 import time
27 time.sleep(n)
29 time.sleep(n)
28 return n
30 return n
29
31
30 def raiser(eclass):
32 def raiser(eclass):
31 """raise an exception"""
33 """raise an exception"""
32 raise eclass()
34 raise eclass()
33
35
34 # test decorator for skipping tests when libraries are unavailable
36 # test decorator for skipping tests when libraries are unavailable
35 def skip_without(*names):
37 def skip_without(*names):
36 """skip a test if some names are not importable"""
38 """skip a test if some names are not importable"""
37 @decorator
39 @decorator
38 def skip_without_names(f, *args, **kwargs):
40 def skip_without_names(f, *args, **kwargs):
39 """decorator to skip tests in the absence of numpy."""
41 """decorator to skip tests in the absence of numpy."""
40 for name in names:
42 for name in names:
41 try:
43 try:
42 __import__(name)
44 __import__(name)
43 except ImportError:
45 except ImportError:
44 raise SkipTest
46 raise SkipTest
45 return f(*args, **kwargs)
47 return f(*args, **kwargs)
46 return skip_without_names
48 return skip_without_names
47
49
48
50
49 class ClusterTestCase(BaseZMQTestCase):
51 class ClusterTestCase(BaseZMQTestCase):
50
52
51 def add_engines(self, n=1, block=True):
53 def add_engines(self, n=1, block=True):
52 """add multiple engines to our cluster"""
54 """add multiple engines to our cluster"""
53 for i in range(n):
55 for i in range(n):
54 self.engines.append(add_engine())
56 self.engines.append(add_engine())
55 if block:
57 if block:
56 self.wait_on_engines()
58 self.wait_on_engines()
57
59
58 def wait_on_engines(self, timeout=5):
60 def wait_on_engines(self, timeout=5):
59 """wait for our engines to connect."""
61 """wait for our engines to connect."""
60 n = len(self.engines)+self.base_engine_count
62 n = len(self.engines)+self.base_engine_count
61 tic = time.time()
63 tic = time.time()
62 while time.time()-tic < timeout and len(self.client.ids) < n:
64 while time.time()-tic < timeout and len(self.client.ids) < n:
63 time.sleep(0.1)
65 time.sleep(0.1)
64
66
65 assert not self.client.ids < n, "waiting for engines timed out"
67 assert not len(self.client.ids) < n, "waiting for engines timed out"
66
68
67 def connect_client(self):
69 def connect_client(self):
68 """connect a client with my Context, and track its sockets for cleanup"""
70 """connect a client with my Context, and track its sockets for cleanup"""
69 c = Client(profile='iptest',context=self.context)
71 c = Client(profile='iptest',context=self.context)
70 for name in filter(lambda n:n.endswith('socket'), dir(c)):
72 for name in filter(lambda n:n.endswith('socket'), dir(c)):
71 self.sockets.append(getattr(c, name))
73 self.sockets.append(getattr(c, name))
72 return c
74 return c
73
75
74 def assertRaisesRemote(self, etype, f, *args, **kwargs):
76 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 try:
77 try:
76 try:
78 try:
77 f(*args, **kwargs)
79 f(*args, **kwargs)
78 except error.CompositeError as e:
80 except error.CompositeError as e:
79 e.raise_exception()
81 e.raise_exception()
80 except error.RemoteError as e:
82 except error.RemoteError as e:
81 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
83 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
82 else:
84 else:
83 self.fail("should have raised a RemoteError")
85 self.fail("should have raised a RemoteError")
84
86
85 def setUp(self):
87 def setUp(self):
86 BaseZMQTestCase.setUp(self)
88 BaseZMQTestCase.setUp(self)
87 self.client = self.connect_client()
89 self.client = self.connect_client()
88 self.base_engine_count=len(self.client.ids)
90 self.base_engine_count=len(self.client.ids)
89 self.engines=[]
91 self.engines=[]
90
92
91 def tearDown(self):
93 def tearDown(self):
94
95 # close fds:
96 for e in filter(lambda e: e.poll() is not None, processes):
97 processes.remove(e)
98
92 self.client.close()
99 self.client.close()
93 BaseZMQTestCase.tearDown(self)
100 BaseZMQTestCase.tearDown(self)
94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
101 # this will be superfluous when pyzmq merges PR #88
95 # [ e.wait() for e in self.engines ]
102 self.context.term()
96 # while len(self.client.ids) > self.base_engine_count:
103 print tempfile.TemporaryFile().fileno(),
97 # time.sleep(.1)
104 sys.stdout.flush()
98 # del self.engines
99 # BaseZMQTestCase.tearDown(self)
100 No newline at end of file
105
@@ -1,262 +1,262 b''
1 import time
1 import time
2 from tempfile import mktemp
2 from tempfile import mktemp
3
3
4 import nose.tools as nt
4 import nose.tools as nt
5 import zmq
5 import zmq
6
6
7 from IPython.zmq.parallel import client as clientmod
7 from IPython.zmq.parallel import client as clientmod
8 from IPython.zmq.parallel import error
8 from IPython.zmq.parallel import error
9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
11
11
12 from clienttest import ClusterTestCase, segfault, wait
12 from clienttest import ClusterTestCase, segfault, wait
13
13
14 class TestClient(ClusterTestCase):
14 class TestClient(ClusterTestCase):
15
15
16 def test_ids(self):
16 def test_ids(self):
17 n = len(self.client.ids)
17 n = len(self.client.ids)
18 self.add_engines(3)
18 self.add_engines(3)
19 self.assertEquals(len(self.client.ids), n+3)
19 self.assertEquals(len(self.client.ids), n+3)
20 self.assertTrue
21
20
22 def test_segfault_task(self):
21 def test_segfault_task(self):
23 """test graceful handling of engine death (balanced)"""
22 """test graceful handling of engine death (balanced)"""
24 self.add_engines(1)
23 self.add_engines(1)
25 ar = self.client.apply(segfault, block=False)
24 ar = self.client.apply(segfault, block=False)
26 self.assertRaisesRemote(error.EngineError, ar.get)
25 self.assertRaisesRemote(error.EngineError, ar.get)
27 eid = ar.engine_id
26 eid = ar.engine_id
28 while eid in self.client.ids:
27 while eid in self.client.ids:
29 time.sleep(.01)
28 time.sleep(.01)
30 self.client.spin()
29 self.client.spin()
31
30
32 def test_segfault_mux(self):
31 def test_segfault_mux(self):
33 """test graceful handling of engine death (direct)"""
32 """test graceful handling of engine death (direct)"""
34 self.add_engines(1)
33 self.add_engines(1)
35 eid = self.client.ids[-1]
34 eid = self.client.ids[-1]
36 ar = self.client[eid].apply_async(segfault)
35 ar = self.client[eid].apply_async(segfault)
37 self.assertRaisesRemote(error.EngineError, ar.get)
36 self.assertRaisesRemote(error.EngineError, ar.get)
38 eid = ar.engine_id
37 eid = ar.engine_id
39 while eid in self.client.ids:
38 while eid in self.client.ids:
40 time.sleep(.01)
39 time.sleep(.01)
41 self.client.spin()
40 self.client.spin()
42
41
43 def test_view_indexing(self):
42 def test_view_indexing(self):
44 """test index access for views"""
43 """test index access for views"""
45 self.add_engines(2)
44 self.add_engines(2)
46 targets = self.client._build_targets('all')[-1]
45 targets = self.client._build_targets('all')[-1]
47 v = self.client[:]
46 v = self.client[:]
48 self.assertEquals(v.targets, targets)
47 self.assertEquals(v.targets, targets)
49 t = self.client.ids[2]
48 t = self.client.ids[2]
50 v = self.client[t]
49 v = self.client[t]
51 self.assert_(isinstance(v, DirectView))
50 self.assert_(isinstance(v, DirectView))
52 self.assertEquals(v.targets, t)
51 self.assertEquals(v.targets, t)
53 t = self.client.ids[2:4]
52 t = self.client.ids[2:4]
54 v = self.client[t]
53 v = self.client[t]
55 self.assert_(isinstance(v, DirectView))
54 self.assert_(isinstance(v, DirectView))
56 self.assertEquals(v.targets, t)
55 self.assertEquals(v.targets, t)
57 v = self.client[::2]
56 v = self.client[::2]
58 self.assert_(isinstance(v, DirectView))
57 self.assert_(isinstance(v, DirectView))
59 self.assertEquals(v.targets, targets[::2])
58 self.assertEquals(v.targets, targets[::2])
60 v = self.client[1::3]
59 v = self.client[1::3]
61 self.assert_(isinstance(v, DirectView))
60 self.assert_(isinstance(v, DirectView))
62 self.assertEquals(v.targets, targets[1::3])
61 self.assertEquals(v.targets, targets[1::3])
63 v = self.client[:-3]
62 v = self.client[:-3]
64 self.assert_(isinstance(v, DirectView))
63 self.assert_(isinstance(v, DirectView))
65 self.assertEquals(v.targets, targets[:-3])
64 self.assertEquals(v.targets, targets[:-3])
66 v = self.client[-1]
65 v = self.client[-1]
67 self.assert_(isinstance(v, DirectView))
66 self.assert_(isinstance(v, DirectView))
68 self.assertEquals(v.targets, targets[-1])
67 self.assertEquals(v.targets, targets[-1])
69 nt.assert_raises(TypeError, lambda : self.client[None])
68 nt.assert_raises(TypeError, lambda : self.client[None])
70
69
71 def test_view_cache(self):
70 def test_view_cache(self):
72 """test that multiple view requests return the same object"""
71 """test that multiple view requests return the same object"""
73 v = self.client[:2]
72 v = self.client[:2]
74 v2 =self.client[:2]
73 v2 =self.client[:2]
75 self.assertTrue(v is v2)
74 self.assertTrue(v is v2)
76 v = self.client.view()
75 v = self.client.view()
77 v2 = self.client.view(balanced=True)
76 v2 = self.client.view(balanced=True)
78 self.assertTrue(v is v2)
77 self.assertTrue(v is v2)
79
78
80 def test_targets(self):
79 def test_targets(self):
81 """test various valid targets arguments"""
80 """test various valid targets arguments"""
82 build = self.client._build_targets
81 build = self.client._build_targets
83 ids = self.client.ids
82 ids = self.client.ids
84 idents,targets = build(None)
83 idents,targets = build(None)
85 self.assertEquals(ids, targets)
84 self.assertEquals(ids, targets)
86
85
87 def test_clear(self):
86 def test_clear(self):
88 """test clear behavior"""
87 """test clear behavior"""
89 self.add_engines(2)
88 self.add_engines(2)
90 self.client.block=True
89 self.client.block=True
91 self.client.push(dict(a=5))
90 self.client.push(dict(a=5))
92 self.client.pull('a')
91 self.client.pull('a')
93 id0 = self.client.ids[-1]
92 id0 = self.client.ids[-1]
94 self.client.clear(targets=id0)
93 self.client.clear(targets=id0)
95 self.client.pull('a', targets=self.client.ids[:-1])
94 self.client.pull('a', targets=self.client.ids[:-1])
96 self.assertRaisesRemote(NameError, self.client.pull, 'a')
95 self.assertRaisesRemote(NameError, self.client.pull, 'a')
97 self.client.clear()
96 self.client.clear()
98 for i in self.client.ids:
97 for i in self.client.ids:
99 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
98 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
100
99
101
100
102 def test_push_pull(self):
101 def test_push_pull(self):
103 """test pushing and pulling"""
102 """test pushing and pulling"""
104 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
103 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 t = self.client.ids[-1]
104 t = self.client.ids[-1]
106 self.add_engines(2)
105 self.add_engines(2)
107 push = self.client.push
106 push = self.client.push
108 pull = self.client.pull
107 pull = self.client.pull
109 self.client.block=True
108 self.client.block=True
110 nengines = len(self.client)
109 nengines = len(self.client)
111 push({'data':data}, targets=t)
110 push({'data':data}, targets=t)
112 d = pull('data', targets=t)
111 d = pull('data', targets=t)
113 self.assertEquals(d, data)
112 self.assertEquals(d, data)
114 push({'data':data})
113 push({'data':data})
115 d = pull('data')
114 d = pull('data')
116 self.assertEquals(d, nengines*[data])
115 self.assertEquals(d, nengines*[data])
117 ar = push({'data':data}, block=False)
116 ar = push({'data':data}, block=False)
118 self.assertTrue(isinstance(ar, AsyncResult))
117 self.assertTrue(isinstance(ar, AsyncResult))
119 r = ar.get()
118 r = ar.get()
120 ar = pull('data', block=False)
119 ar = pull('data', block=False)
121 self.assertTrue(isinstance(ar, AsyncResult))
120 self.assertTrue(isinstance(ar, AsyncResult))
122 r = ar.get()
121 r = ar.get()
123 self.assertEquals(r, nengines*[data])
122 self.assertEquals(r, nengines*[data])
124 push(dict(a=10,b=20))
123 push(dict(a=10,b=20))
125 r = pull(('a','b'))
124 r = pull(('a','b'))
126 self.assertEquals(r, nengines*[[10,20]])
125 self.assertEquals(r, nengines*[[10,20]])
127
126
128 def test_push_pull_function(self):
127 def test_push_pull_function(self):
129 "test pushing and pulling functions"
128 "test pushing and pulling functions"
130 def testf(x):
129 def testf(x):
131 return 2.0*x
130 return 2.0*x
132
131
133 self.add_engines(4)
132 self.add_engines(4)
134 t = self.client.ids[-1]
133 t = self.client.ids[-1]
135 self.client.block=True
134 self.client.block=True
136 push = self.client.push
135 push = self.client.push
137 pull = self.client.pull
136 pull = self.client.pull
138 execute = self.client.execute
137 execute = self.client.execute
139 push({'testf':testf}, targets=t)
138 push({'testf':testf}, targets=t)
140 r = pull('testf', targets=t)
139 r = pull('testf', targets=t)
141 self.assertEqual(r(1.0), testf(1.0))
140 self.assertEqual(r(1.0), testf(1.0))
142 execute('r = testf(10)', targets=t)
141 execute('r = testf(10)', targets=t)
143 r = pull('r', targets=t)
142 r = pull('r', targets=t)
144 self.assertEquals(r, testf(10))
143 self.assertEquals(r, testf(10))
145 ar = push({'testf':testf}, block=False)
144 ar = push({'testf':testf}, block=False)
146 ar.get()
145 ar.get()
147 ar = pull('testf', block=False)
146 ar = pull('testf', block=False)
148 rlist = ar.get()
147 rlist = ar.get()
149 for r in rlist:
148 for r in rlist:
150 self.assertEqual(r(1.0), testf(1.0))
149 self.assertEqual(r(1.0), testf(1.0))
151 execute("def g(x): return x*x", targets=t)
150 execute("def g(x): return x*x", targets=t)
152 r = pull(('testf','g'),targets=t)
151 r = pull(('testf','g'),targets=t)
153 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
152 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
154
153
155 def test_push_function_globals(self):
154 def test_push_function_globals(self):
156 """test that pushed functions have access to globals"""
155 """test that pushed functions have access to globals"""
157 def geta():
156 def geta():
158 return a
157 return a
159 self.add_engines(1)
158 self.add_engines(1)
160 v = self.client[-1]
159 v = self.client[-1]
161 v.block=True
160 v.block=True
162 v['f'] = geta
161 v['f'] = geta
163 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
162 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
164 v.execute('a=5')
163 v.execute('a=5')
165 v.execute('b=f()')
164 v.execute('b=f()')
166 self.assertEquals(v['b'], 5)
165 self.assertEquals(v['b'], 5)
167
166
168 def test_push_function_defaults(self):
167 def test_push_function_defaults(self):
169 """test that pushed functions preserve default args"""
168 """test that pushed functions preserve default args"""
170 def echo(a=10):
169 def echo(a=10):
171 return a
170 return a
172 self.add_engines(1)
171 self.add_engines(1)
173 v = self.client[-1]
172 v = self.client[-1]
174 v.block=True
173 v.block=True
175 v['f'] = echo
174 v['f'] = echo
176 v.execute('b=f()')
175 v.execute('b=f()')
177 self.assertEquals(v['b'], 10)
176 self.assertEquals(v['b'], 10)
178
177
179 def test_get_result(self):
178 def test_get_result(self):
180 """test getting results from the Hub."""
179 """test getting results from the Hub."""
181 c = clientmod.Client(profile='iptest')
180 c = clientmod.Client(profile='iptest')
182 t = self.client.ids[-1]
181 self.add_engines(1)
183 ar = c.apply(wait, (1,), block=False, targets=t)
182 ar = c.apply(wait, (1,), block=False, targets=t)
183 # give the monitor time to notice the message
184 time.sleep(.25)
184 time.sleep(.25)
185 ahr = self.client.get_result(ar.msg_ids)
185 ahr = self.client.get_result(ar.msg_ids)
186 self.assertTrue(isinstance(ahr, AsyncHubResult))
186 self.assertTrue(isinstance(ahr, AsyncHubResult))
187 self.assertEquals(ahr.get(), ar.get())
187 self.assertEquals(ahr.get(), ar.get())
188 ar2 = self.client.get_result(ar.msg_ids)
188 ar2 = self.client.get_result(ar.msg_ids)
189 self.assertFalse(isinstance(ar2, AsyncHubResult))
189 self.assertFalse(isinstance(ar2, AsyncHubResult))
190
190
191 def test_ids_list(self):
191 def test_ids_list(self):
192 """test client.ids"""
192 """test client.ids"""
193 self.add_engines(2)
193 self.add_engines(2)
194 ids = self.client.ids
194 ids = self.client.ids
195 self.assertEquals(ids, self.client._ids)
195 self.assertEquals(ids, self.client._ids)
196 self.assertFalse(ids is self.client._ids)
196 self.assertFalse(ids is self.client._ids)
197 ids.remove(ids[-1])
197 ids.remove(ids[-1])
198 self.assertNotEquals(ids, self.client._ids)
198 self.assertNotEquals(ids, self.client._ids)
199
199
200 def test_run_newline(self):
200 def test_run_newline(self):
201 """test that run appends newline to files"""
201 """test that run appends newline to files"""
202 tmpfile = mktemp()
202 tmpfile = mktemp()
203 with open(tmpfile, 'w') as f:
203 with open(tmpfile, 'w') as f:
204 f.write("""def g():
204 f.write("""def g():
205 return 5
205 return 5
206 """)
206 """)
207 v = self.client[-1]
207 v = self.client[-1]
208 v.run(tmpfile, block=True)
208 v.run(tmpfile, block=True)
209 self.assertEquals(v.apply_sync(lambda : g()), 5)
209 self.assertEquals(v.apply_sync(lambda : g()), 5)
210
210
211 def test_apply_tracked(self):
211 def test_apply_tracked(self):
212 """test tracking for apply"""
212 """test tracking for apply"""
213 # self.add_engines(1)
213 # self.add_engines(1)
214 t = self.client.ids[-1]
214 t = self.client.ids[-1]
215 self.client.block=False
215 self.client.block=False
216 def echo(n=1024*1024, **kwargs):
216 def echo(n=1024*1024, **kwargs):
217 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
217 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
218 ar = echo(1)
218 ar = echo(1)
219 self.assertTrue(ar._tracker is None)
219 self.assertTrue(ar._tracker is None)
220 self.assertTrue(ar.sent)
220 self.assertTrue(ar.sent)
221 ar = echo(track=True)
221 ar = echo(track=True)
222 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
222 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
223 self.assertEquals(ar.sent, ar._tracker.done)
223 self.assertEquals(ar.sent, ar._tracker.done)
224 ar._tracker.wait()
224 ar._tracker.wait()
225 self.assertTrue(ar.sent)
225 self.assertTrue(ar.sent)
226
226
227 def test_push_tracked(self):
227 def test_push_tracked(self):
228 t = self.client.ids[-1]
228 t = self.client.ids[-1]
229 ns = dict(x='x'*1024*1024)
229 ns = dict(x='x'*1024*1024)
230 ar = self.client.push(ns, targets=t, block=False)
230 ar = self.client.push(ns, targets=t, block=False)
231 self.assertTrue(ar._tracker is None)
231 self.assertTrue(ar._tracker is None)
232 self.assertTrue(ar.sent)
232 self.assertTrue(ar.sent)
233
233
234 ar = self.client.push(ns, targets=t, block=False, track=True)
234 ar = self.client.push(ns, targets=t, block=False, track=True)
235 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
235 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
236 self.assertEquals(ar.sent, ar._tracker.done)
236 self.assertEquals(ar.sent, ar._tracker.done)
237 ar._tracker.wait()
237 ar._tracker.wait()
238 self.assertTrue(ar.sent)
238 self.assertTrue(ar.sent)
239 ar.get()
239 ar.get()
240
240
241 def test_scatter_tracked(self):
241 def test_scatter_tracked(self):
242 t = self.client.ids
242 t = self.client.ids
243 x='x'*1024*1024
243 x='x'*1024*1024
244 ar = self.client.scatter('x', x, targets=t, block=False)
244 ar = self.client.scatter('x', x, targets=t, block=False)
245 self.assertTrue(ar._tracker is None)
245 self.assertTrue(ar._tracker is None)
246 self.assertTrue(ar.sent)
246 self.assertTrue(ar.sent)
247
247
248 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
248 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
249 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
249 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
250 self.assertEquals(ar.sent, ar._tracker.done)
250 self.assertEquals(ar.sent, ar._tracker.done)
251 ar._tracker.wait()
251 ar._tracker.wait()
252 self.assertTrue(ar.sent)
252 self.assertTrue(ar.sent)
253 ar.get()
253 ar.get()
254
254
255 def test_remote_reference(self):
255 def test_remote_reference(self):
256 v = self.client[-1]
256 v = self.client[-1]
257 v['a'] = 123
257 v['a'] = 123
258 ra = clientmod.Reference('a')
258 ra = clientmod.Reference('a')
259 b = v.apply_sync(lambda x: x, ra)
259 b = v.apply_sync(lambda x: x, ra)
260 self.assertEquals(b, 123)
260 self.assertEquals(b, 123)
261
261
262
262
1 NO CONTENT: modified file, binary diff hidden
NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file
NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file, binary diff hidden
NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file
NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed, binary diff hidden
NO CONTENT: file was removed, binary diff hidden
1 NO CONTENT: file was removed, binary diff hidden
NO CONTENT: file was removed, binary diff hidden
General Comments 0
You need to be logged in to leave comments. Login now