##// END OF EJS Templates
add Client.resubmit for re-running tasks...
MinRK -
Show More
@@ -1,1292 +1,1354 b''
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 Dict, List, Bool, Str, Set)
28 Dict, List, Bool, Str, Set)
29 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
30 from IPython.external.ssh import tunnel
30 from IPython.external.ssh import tunnel
31
31
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import streamsession as ss
33 from IPython.parallel import streamsession as ss
34 from IPython.parallel import util
34 from IPython.parallel import util
35
35
36 from .asyncresult import AsyncResult, AsyncHubResult
36 from .asyncresult import AsyncResult, AsyncHubResult
37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
38 from .view import DirectView, LoadBalancedView
38 from .view import DirectView, LoadBalancedView
39
39
40 #--------------------------------------------------------------------------
40 #--------------------------------------------------------------------------
41 # Decorators for Client methods
41 # Decorators for Client methods
42 #--------------------------------------------------------------------------
42 #--------------------------------------------------------------------------
43
43
@decorator
def spin_first(f, self, *args, **kwargs):
    """Decorator for Client methods: flush incoming messages before running.

    Calling spin() first ensures the client's registration state and
    result cache are up to date when the wrapped method executes.
    """
    self.spin()
    result = f(self, *args, **kwargs)
    return result
49
49
50
50
51 #--------------------------------------------------------------------------
51 #--------------------------------------------------------------------------
52 # Classes
52 # Classes
53 #--------------------------------------------------------------------------
53 #--------------------------------------------------------------------------
54
54
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full set of allowed keys, with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE(review): dict.update bypasses our strict __setitem__, so
        # constructor arguments CAN introduce keys outside the set above.
        # Preserved as-is for backward compatibility.
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is an O(1) membership test; the previous
        # `key in self.iterkeys()` scanned an iterator and was Python-2-only.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict key enforcement"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
105
105
106
106
107 class Client(HasTraits):
107 class Client(HasTraits):
108 """A semi-synchronous client to the IPython ZMQ cluster
108 """A semi-synchronous client to the IPython ZMQ cluster
109
109
110 Parameters
110 Parameters
111 ----------
111 ----------
112
112
113 url_or_file : bytes; zmq url or path to ipcontroller-client.json
113 url_or_file : bytes; zmq url or path to ipcontroller-client.json
114 Connection information for the Hub's registration. If a json connector
114 Connection information for the Hub's registration. If a json connector
115 file is given, then likely no further configuration is necessary.
115 file is given, then likely no further configuration is necessary.
116 [Default: use profile]
116 [Default: use profile]
117 profile : bytes
117 profile : bytes
118 The name of the Cluster profile to be used to find connector information.
118 The name of the Cluster profile to be used to find connector information.
119 [Default: 'default']
119 [Default: 'default']
120 context : zmq.Context
120 context : zmq.Context
121 Pass an existing zmq.Context instance, otherwise the client will create its own.
121 Pass an existing zmq.Context instance, otherwise the client will create its own.
122 username : bytes
122 username : bytes
123 set username to be passed to the Session object
123 set username to be passed to the Session object
124 debug : bool
124 debug : bool
125 flag for lots of message printing for debug purposes
125 flag for lots of message printing for debug purposes
126
126
127 #-------------- ssh related args ----------------
127 #-------------- ssh related args ----------------
128 # These are args for configuring the ssh tunnel to be used
128 # These are args for configuring the ssh tunnel to be used
129 # credentials are used to forward connections over ssh to the Controller
129 # credentials are used to forward connections over ssh to the Controller
130 # Note that the ip given in `addr` needs to be relative to sshserver
130 # Note that the ip given in `addr` needs to be relative to sshserver
131 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
131 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
132 # and set sshserver as the same machine the Controller is on. However,
132 # and set sshserver as the same machine the Controller is on. However,
133 # the only requirement is that sshserver is able to see the Controller
133 # the only requirement is that sshserver is able to see the Controller
134 # (i.e. is within the same trusted network).
134 # (i.e. is within the same trusted network).
135
135
136 sshserver : str
136 sshserver : str
137 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
137 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
138 If keyfile or password is specified, and this is not, it will default to
138 If keyfile or password is specified, and this is not, it will default to
139 the ip given in addr.
139 the ip given in addr.
140 sshkey : str; path to public ssh key file
140 sshkey : str; path to public ssh key file
141 This specifies a key to be used in ssh login, default None.
141 This specifies a key to be used in ssh login, default None.
142 Regular default ssh keys will be used without specifying this argument.
142 Regular default ssh keys will be used without specifying this argument.
143 password : str
143 password : str
144 Your ssh password to sshserver. Note that if this is left None,
144 Your ssh password to sshserver. Note that if this is left None,
145 you will be prompted for it if passwordless key based login is unavailable.
145 you will be prompted for it if passwordless key based login is unavailable.
146 paramiko : bool
146 paramiko : bool
147 flag for whether to use paramiko instead of shell ssh for tunneling.
147 flag for whether to use paramiko instead of shell ssh for tunneling.
148 [default: True on win32, False else]
148 [default: True on win32, False else]
149
149
150 ------- exec authentication args -------
150 ------- exec authentication args -------
151 If even localhost is untrusted, you can have some protection against
151 If even localhost is untrusted, you can have some protection against
152 unauthorized execution by using a key. Messages are still sent
152 unauthorized execution by using a key. Messages are still sent
153 as cleartext, so if someone can snoop your loopback traffic this will
153 as cleartext, so if someone can snoop your loopback traffic this will
154 not help against malicious attacks.
154 not help against malicious attacks.
155
155
156 exec_key : str
156 exec_key : str
157 an authentication key or file containing a key
157 an authentication key or file containing a key
158 default: None
158 default: None
159
159
160
160
161 Attributes
161 Attributes
162 ----------
162 ----------
163
163
164 ids : list of int engine IDs
164 ids : list of int engine IDs
165 requesting the ids attribute always synchronizes
165 requesting the ids attribute always synchronizes
166 the registration state. To request ids without synchronization,
166 the registration state. To request ids without synchronization,
167 use semi-private _ids attributes.
167 use semi-private _ids attributes.
168
168
169 history : list of msg_ids
169 history : list of msg_ids
170 a list of msg_ids, keeping track of all the execution
170 a list of msg_ids, keeping track of all the execution
171 messages you have submitted in order.
171 messages you have submitted in order.
172
172
173 outstanding : set of msg_ids
173 outstanding : set of msg_ids
174 a set of msg_ids that have been submitted, but whose
174 a set of msg_ids that have been submitted, but whose
175 results have not yet been received.
175 results have not yet been received.
176
176
177 results : dict
177 results : dict
178 a dict of all our results, keyed by msg_id
178 a dict of all our results, keyed by msg_id
179
179
180 block : bool
180 block : bool
181 determines default behavior when block not specified
181 determines default behavior when block not specified
182 in execution methods
182 in execution methods
183
183
184 Methods
184 Methods
185 -------
185 -------
186
186
187 spin
187 spin
188 flushes incoming results and registration state changes
188 flushes incoming results and registration state changes
189 control methods spin, and requesting `ids` also ensures up to date
189 control methods spin, and requesting `ids` also ensures up to date
190
190
191 wait
191 wait
192 wait on one or more msg_ids
192 wait on one or more msg_ids
193
193
194 execution methods
194 execution methods
195 apply
195 apply
196 legacy: execute, run
196 legacy: execute, run
197
197
198 data movement
198 data movement
199 push, pull, scatter, gather
199 push, pull, scatter, gather
200
200
201 query methods
201 query methods
202 queue_status, get_result, purge, result_status
202 queue_status, get_result, purge, result_status
203
203
204 control methods
204 control methods
205 abort, shutdown
205 abort, shutdown
206
206
207 """
207 """
208
208
209
209
    # --- public state (see class docstring) -------------------------------
    block = Bool(False)       # default blocking behavior for execution methods
    outstanding = Set()       # msg_ids submitted but whose results have not arrived
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> Metadata
    history = List()          # all execution msg_ids, in submission order
    debug = Bool(False)       # verbose message printing for debugging
    profile=CUnicode('default')

    # --- private state ----------------------------------------------------
    _outstanding_dict = Instance('collections.defaultdict', (set,))  # engine uuid -> its outstanding msg_ids
    _ids = List()             # registered engine ids (kept sorted by _update_engines)
    _connected=Bool(False)
    _ssh=Bool(False)          # whether connections are tunneled over ssh
    _context = Instance('zmq.Context')
    _config = Dict()          # merged connection config (args + json connector file)
    _engines=Instance(util.ReverseDict, (), {})  # two-way engine id <-> uuid mapping
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Str()        # task scheduler scheme; 'pure' means raw ZMQ routing -- other values set by Hub
    _closed = False
    _ignored_control_replies=Int(0)  # control replies to silently discard (non-blocking requests)
    _ignored_hub_replies=Int(0)      # hub replies to silently discard
236
236
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10
            ):
        """Connect to the Hub, merging connection info from args and the
        cluster profile's json connector file.  See the class docstring for
        parameter descriptions."""
        super(Client, self).__init__(debug=debug, profile=profile)
        # share the process-global zmq context unless one was passed in
        if context is None:
            context = zmq.Context.instance()
        self._context = context
        

        # locate the cluster dir; if found, default to its connector json file
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."
        
        # url_or_file may be either a zmq url or a path to a json connector
        # file; validate_url raises AssertionError for non-urls
        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                # not an absolute/relative path that exists; retry inside
                # the cluster dir's security directory
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}
        
        # sync defaults from args, json:
        # explicit arguments win over values from the json file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        # NOTE(review): these lookups assume 'exec_key' and 'ssh' keys exist
        # in the json connector file when not given as args -- KeyError otherwise
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        # rewrite 0.0.0.0/localhost urls relative to `location`
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        
        self._config = cfg
        
        # any ssh credential implies tunneling
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                # False tells the tunnel machinery not to prompt
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be a literal key or a path to a key file
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        # the query socket talks to the Hub's registration/query interface;
        # identity is our session id so the Hub can route replies back
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages, by msg_type
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        # perform the actual registration handshake and socket setup
        self._connect(sshserver, ssh_kwargs, timeout)
316
316
317 def __del__(self):
317 def __del__(self):
318 """cleanup sockets, but _not_ context."""
318 """cleanup sockets, but _not_ context."""
319 self.close()
319 self.close()
320
320
321 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
321 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
322 if ipython_dir is None:
322 if ipython_dir is None:
323 ipython_dir = get_ipython_dir()
323 ipython_dir = get_ipython_dir()
324 if cluster_dir is not None:
324 if cluster_dir is not None:
325 try:
325 try:
326 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
326 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
327 return
327 return
328 except ClusterDirError:
328 except ClusterDirError:
329 pass
329 pass
330 elif profile is not None:
330 elif profile is not None:
331 try:
331 try:
332 self._cd = ClusterDir.find_cluster_dir_by_profile(
332 self._cd = ClusterDir.find_cluster_dir_by_profile(
333 ipython_dir, profile)
333 ipython_dir, profile)
334 return
334 return
335 except ClusterDirError:
335 except ClusterDirError:
336 pass
336 pass
337 self._cd = None
337 self._cd = None
338
338
339 def _update_engines(self, engines):
339 def _update_engines(self, engines):
340 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
340 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
341 for k,v in engines.iteritems():
341 for k,v in engines.iteritems():
342 eid = int(k)
342 eid = int(k)
343 self._engines[eid] = bytes(v) # force not unicode
343 self._engines[eid] = bytes(v) # force not unicode
344 self._ids.append(eid)
344 self._ids.append(eid)
345 self._ids = sorted(self._ids)
345 self._ids = sorted(self._ids)
346 if sorted(self._engines.keys()) != range(len(self._engines)) and \
346 if sorted(self._engines.keys()) != range(len(self._engines)) and \
347 self._task_scheme == 'pure' and self._task_socket:
347 self._task_scheme == 'pure' and self._task_socket:
348 self._stop_scheduling_tasks()
348 self._stop_scheduling_tasks()
349
349
350 def _stop_scheduling_tasks(self):
350 def _stop_scheduling_tasks(self):
351 """Stop scheduling tasks because an engine has been unregistered
351 """Stop scheduling tasks because an engine has been unregistered
352 from a pure ZMQ scheduler.
352 from a pure ZMQ scheduler.
353 """
353 """
354 self._task_socket.close()
354 self._task_socket.close()
355 self._task_socket = None
355 self._task_socket = None
356 msg = "An engine has been unregistered, and we are using pure " +\
356 msg = "An engine has been unregistered, and we are using pure " +\
357 "ZMQ task scheduling. Task farming will be disabled."
357 "ZMQ task scheduling. Task farming will be disabled."
358 if self.outstanding:
358 if self.outstanding:
359 msg += " If you were running tasks when this happened, " +\
359 msg += " If you were running tasks when this happened, " +\
360 "some `outstanding` msg_ids may never resolve."
360 "some `outstanding` msg_ids may never resolve."
361 warnings.warn(msg, RuntimeWarning)
361 warnings.warn(msg, RuntimeWarning)
362
362
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Accepts None (all engines), 'all', a single int (negative allowed,
        counts from the end), a slice, or a collection of ints.
        Raises NoEnginesRegistered if no engines are connected,
        IndexError for an unknown engine id, TypeError for other inputs.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            # (accessing self.ids spins, which may register new engines)
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index counts from the end of the sorted id list
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # slice over positions in the id list, not over raw ids
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        # xrange accepted here: Python-2-era code
        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [self._engines[t] for t in targets], list(targets)
395
395
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits up
        to `timeout` seconds for the reply, then connects one socket per
        channel the Hub advertises (mux, task, notification, control, iopub).
        Raises TimeoutError if the Hub does not reply in time.
        """
        
        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # rewrite 0.0.0.0/localhost urls relative to the Hub's location,
            # tunneling over ssh when configured
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)
        
        self.session.send(self._query_socket, 'connection_request')
        # wait (with timeout) for the Hub's reply before blocking on recv
        r,w,x = zmq.select([self._query_socket],[],[], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # each channel the Hub advertises gets its own socket;
            # XREQ sockets carry our session id so replies route back to us
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # content.task is (scheme, addr); scheme e.g. 'pure'
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # subscribe to everything on the notification channel
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our engine registry from the Hub's current engine list
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
453
453
454 #--------------------------------------------------------------------------
454 #--------------------------------------------------------------------------
455 # handlers and callbacks for incoming messages
455 # handlers and callbacks for incoming messages
456 #--------------------------------------------------------------------------
456 #--------------------------------------------------------------------------
457
457
458 def _unwrap_exception(self, content):
458 def _unwrap_exception(self, content):
459 """unwrap exception, and remap engine_id to int."""
459 """unwrap exception, and remap engine_id to int."""
460 e = error.unwrap_exception(content)
460 e = error.unwrap_exception(content)
461 # print e.traceback
461 # print e.traceback
462 if e.engine_info:
462 if e.engine_info:
463 e_uuid = e.engine_info['engine_uuid']
463 e_uuid = e.engine_info['engine_uuid']
464 eid = self._engines[e_uuid]
464 eid = self._engines[e_uuid]
465 e.engine_info['engine_id'] = eid
465 e.engine_info['engine_id'] = eid
466 return e
466 return e
467
467
468 def _extract_metadata(self, header, parent, content):
468 def _extract_metadata(self, header, parent, content):
469 md = {'msg_id' : parent['msg_id'],
469 md = {'msg_id' : parent['msg_id'],
470 'received' : datetime.now(),
470 'received' : datetime.now(),
471 'engine_uuid' : header.get('engine', None),
471 'engine_uuid' : header.get('engine', None),
472 'follow' : parent.get('follow', []),
472 'follow' : parent.get('follow', []),
473 'after' : parent.get('after', []),
473 'after' : parent.get('after', []),
474 'status' : content['status'],
474 'status' : content['status'],
475 }
475 }
476
476
477 if md['engine_uuid'] is not None:
477 if md['engine_uuid'] is not None:
478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479
479
480 if 'date' in parent:
480 if 'date' in parent:
481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
482 if 'started' in header:
482 if 'started' in header:
483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
484 if 'date' in header:
484 if 'date' in header:
485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
486 return md
486 return md
487
487
488 def _register_engine(self, msg):
488 def _register_engine(self, msg):
489 """Register a new engine, and update our connection info."""
489 """Register a new engine, and update our connection info."""
490 content = msg['content']
490 content = msg['content']
491 eid = content['id']
491 eid = content['id']
492 d = {eid : content['queue']}
492 d = {eid : content['queue']}
493 self._update_engines(d)
493 self._update_engines(d)
494
494
495 def _unregister_engine(self, msg):
495 def _unregister_engine(self, msg):
496 """Unregister an engine that has died."""
496 """Unregister an engine that has died."""
497 content = msg['content']
497 content = msg['content']
498 eid = int(content['id'])
498 eid = int(content['id'])
499 if eid in self._ids:
499 if eid in self._ids:
500 self._ids.remove(eid)
500 self._ids.remove(eid)
501 uuid = self._engines.pop(eid)
501 uuid = self._engines.pop(eid)
502
502
503 self._handle_stranded_msgs(eid, uuid)
503 self._handle_stranded_msgs(eid, uuid)
504
504
505 if self._task_socket and self._task_scheme == 'pure':
505 if self._task_socket and self._task_scheme == 'pure':
506 self._stop_scheduling_tasks()
506 self._stop_scheduling_tasks()
507
507
508 def _handle_stranded_msgs(self, eid, uuid):
508 def _handle_stranded_msgs(self, eid, uuid):
509 """Handle messages known to be on an engine when the engine unregisters.
509 """Handle messages known to be on an engine when the engine unregisters.
510
510
511 It is possible that this will fire prematurely - that is, an engine will
511 It is possible that this will fire prematurely - that is, an engine will
512 go down after completing a result, and the client will be notified
512 go down after completing a result, and the client will be notified
513 of the unregistration and later receive the successful result.
513 of the unregistration and later receive the successful result.
514 """
514 """
515
515
516 outstanding = self._outstanding_dict[uuid]
516 outstanding = self._outstanding_dict[uuid]
517
517
518 for msg_id in list(outstanding):
518 for msg_id in list(outstanding):
519 if msg_id in self.results:
519 if msg_id in self.results:
520 # we already
520 # we already
521 continue
521 continue
522 try:
522 try:
523 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
523 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
524 except:
524 except:
525 content = error.wrap_exception()
525 content = error.wrap_exception()
526 # build a fake message:
526 # build a fake message:
527 parent = {}
527 parent = {}
528 header = {}
528 header = {}
529 parent['msg_id'] = msg_id
529 parent['msg_id'] = msg_id
530 header['engine'] = uuid
530 header['engine'] = uuid
531 header['date'] = datetime.now().strftime(util.ISO8601)
531 header['date'] = datetime.now().strftime(util.ISO8601)
532 msg = dict(parent_header=parent, header=header, content=content)
532 msg = dict(parent_header=parent, header=header, content=content)
533 self._handle_apply_reply(msg)
533 self._handle_apply_reply(msg)
534
534
535 def _handle_execute_reply(self, msg):
535 def _handle_execute_reply(self, msg):
536 """Save the reply to an execute_request into our results.
536 """Save the reply to an execute_request into our results.
537
537
538 execute messages are never actually used. apply is used instead.
538 execute messages are never actually used. apply is used instead.
539 """
539 """
540
540
541 parent = msg['parent_header']
541 parent = msg['parent_header']
542 msg_id = parent['msg_id']
542 msg_id = parent['msg_id']
543 if msg_id not in self.outstanding:
543 if msg_id not in self.outstanding:
544 if msg_id in self.history:
544 if msg_id in self.history:
545 print ("got stale result: %s"%msg_id)
545 print ("got stale result: %s"%msg_id)
546 else:
546 else:
547 print ("got unknown result: %s"%msg_id)
547 print ("got unknown result: %s"%msg_id)
548 else:
548 else:
549 self.outstanding.remove(msg_id)
549 self.outstanding.remove(msg_id)
550 self.results[msg_id] = self._unwrap_exception(msg['content'])
550 self.results[msg_id] = self._unwrap_exception(msg['content'])
551
551
552 def _handle_apply_reply(self, msg):
552 def _handle_apply_reply(self, msg):
553 """Save the reply to an apply_request into our results."""
553 """Save the reply to an apply_request into our results."""
554 parent = msg['parent_header']
554 parent = msg['parent_header']
555 msg_id = parent['msg_id']
555 msg_id = parent['msg_id']
556 if msg_id not in self.outstanding:
556 if msg_id not in self.outstanding:
557 if msg_id in self.history:
557 if msg_id in self.history:
558 print ("got stale result: %s"%msg_id)
558 print ("got stale result: %s"%msg_id)
559 print self.results[msg_id]
559 print self.results[msg_id]
560 print msg
560 print msg
561 else:
561 else:
562 print ("got unknown result: %s"%msg_id)
562 print ("got unknown result: %s"%msg_id)
563 else:
563 else:
564 self.outstanding.remove(msg_id)
564 self.outstanding.remove(msg_id)
565 content = msg['content']
565 content = msg['content']
566 header = msg['header']
566 header = msg['header']
567
567
568 # construct metadata:
568 # construct metadata:
569 md = self.metadata[msg_id]
569 md = self.metadata[msg_id]
570 md.update(self._extract_metadata(header, parent, content))
570 md.update(self._extract_metadata(header, parent, content))
571 # is this redundant?
571 # is this redundant?
572 self.metadata[msg_id] = md
572 self.metadata[msg_id] = md
573
573
574 e_outstanding = self._outstanding_dict[md['engine_uuid']]
574 e_outstanding = self._outstanding_dict[md['engine_uuid']]
575 if msg_id in e_outstanding:
575 if msg_id in e_outstanding:
576 e_outstanding.remove(msg_id)
576 e_outstanding.remove(msg_id)
577
577
578 # construct result:
578 # construct result:
579 if content['status'] == 'ok':
579 if content['status'] == 'ok':
580 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
580 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
581 elif content['status'] == 'aborted':
581 elif content['status'] == 'aborted':
582 self.results[msg_id] = error.TaskAborted(msg_id)
582 self.results[msg_id] = error.TaskAborted(msg_id)
583 elif content['status'] == 'resubmitted':
583 elif content['status'] == 'resubmitted':
584 # TODO: handle resubmission
584 # TODO: handle resubmission
585 pass
585 pass
586 else:
586 else:
587 self.results[msg_id] = self._unwrap_exception(content)
587 self.results[msg_id] = self._unwrap_exception(content)
588
588
589 def _flush_notifications(self):
589 def _flush_notifications(self):
590 """Flush notifications of engine registrations waiting
590 """Flush notifications of engine registrations waiting
591 in ZMQ queue."""
591 in ZMQ queue."""
592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 while msg is not None:
593 while msg is not None:
594 if self.debug:
594 if self.debug:
595 pprint(msg)
595 pprint(msg)
596 msg = msg[-1]
596 msg = msg[-1]
597 msg_type = msg['msg_type']
597 msg_type = msg['msg_type']
598 handler = self._notification_handlers.get(msg_type, None)
598 handler = self._notification_handlers.get(msg_type, None)
599 if handler is None:
599 if handler is None:
600 raise Exception("Unhandled message type: %s"%msg.msg_type)
600 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 else:
601 else:
602 handler(msg)
602 handler(msg)
603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604
604
605 def _flush_results(self, sock):
605 def _flush_results(self, sock):
606 """Flush task or queue results waiting in ZMQ queue."""
606 """Flush task or queue results waiting in ZMQ queue."""
607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 while msg is not None:
608 while msg is not None:
609 if self.debug:
609 if self.debug:
610 pprint(msg)
610 pprint(msg)
611 msg = msg[-1]
611 msg = msg[-1]
612 msg_type = msg['msg_type']
612 msg_type = msg['msg_type']
613 handler = self._queue_handlers.get(msg_type, None)
613 handler = self._queue_handlers.get(msg_type, None)
614 if handler is None:
614 if handler is None:
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
616 else:
616 else:
617 handler(msg)
617 handler(msg)
618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619
619
620 def _flush_control(self, sock):
620 def _flush_control(self, sock):
621 """Flush replies from the control channel waiting
621 """Flush replies from the control channel waiting
622 in the ZMQ queue.
622 in the ZMQ queue.
623
623
624 Currently: ignore them."""
624 Currently: ignore them."""
625 if self._ignored_control_replies <= 0:
625 if self._ignored_control_replies <= 0:
626 return
626 return
627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 while msg is not None:
628 while msg is not None:
629 self._ignored_control_replies -= 1
629 self._ignored_control_replies -= 1
630 if self.debug:
630 if self.debug:
631 pprint(msg)
631 pprint(msg)
632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633
633
634 def _flush_ignored_control(self):
634 def _flush_ignored_control(self):
635 """flush ignored control replies"""
635 """flush ignored control replies"""
636 while self._ignored_control_replies > 0:
636 while self._ignored_control_replies > 0:
637 self.session.recv(self._control_socket)
637 self.session.recv(self._control_socket)
638 self._ignored_control_replies -= 1
638 self._ignored_control_replies -= 1
639
639
640 def _flush_ignored_hub_replies(self):
640 def _flush_ignored_hub_replies(self):
641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 while msg is not None:
642 while msg is not None:
643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644
644
645 def _flush_iopub(self, sock):
645 def _flush_iopub(self, sock):
646 """Flush replies from the iopub channel waiting
646 """Flush replies from the iopub channel waiting
647 in the ZMQ queue.
647 in the ZMQ queue.
648 """
648 """
649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 while msg is not None:
650 while msg is not None:
651 if self.debug:
651 if self.debug:
652 pprint(msg)
652 pprint(msg)
653 msg = msg[-1]
653 msg = msg[-1]
654 parent = msg['parent_header']
654 parent = msg['parent_header']
655 msg_id = parent['msg_id']
655 msg_id = parent['msg_id']
656 content = msg['content']
656 content = msg['content']
657 header = msg['header']
657 header = msg['header']
658 msg_type = msg['msg_type']
658 msg_type = msg['msg_type']
659
659
660 # init metadata:
660 # init metadata:
661 md = self.metadata[msg_id]
661 md = self.metadata[msg_id]
662
662
663 if msg_type == 'stream':
663 if msg_type == 'stream':
664 name = content['name']
664 name = content['name']
665 s = md[name] or ''
665 s = md[name] or ''
666 md[name] = s + content['data']
666 md[name] = s + content['data']
667 elif msg_type == 'pyerr':
667 elif msg_type == 'pyerr':
668 md.update({'pyerr' : self._unwrap_exception(content)})
668 md.update({'pyerr' : self._unwrap_exception(content)})
669 elif msg_type == 'pyin':
669 elif msg_type == 'pyin':
670 md.update({'pyin' : content['code']})
670 md.update({'pyin' : content['code']})
671 else:
671 else:
672 md.update({msg_type : content.get('data', '')})
672 md.update({msg_type : content.get('data', '')})
673
673
674 # reduntant?
674 # reduntant?
675 self.metadata[msg_id] = md
675 self.metadata[msg_id] = md
676
676
677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678
678
679 #--------------------------------------------------------------------------
679 #--------------------------------------------------------------------------
680 # len, getitem
680 # len, getitem
681 #--------------------------------------------------------------------------
681 #--------------------------------------------------------------------------
682
682
683 def __len__(self):
683 def __len__(self):
684 """len(client) returns # of engines."""
684 """len(client) returns # of engines."""
685 return len(self.ids)
685 return len(self.ids)
686
686
687 def __getitem__(self, key):
687 def __getitem__(self, key):
688 """index access returns DirectView multiplexer objects
688 """index access returns DirectView multiplexer objects
689
689
690 Must be int, slice, or list/tuple/xrange of ints"""
690 Must be int, slice, or list/tuple/xrange of ints"""
691 if not isinstance(key, (int, slice, tuple, list, xrange)):
691 if not isinstance(key, (int, slice, tuple, list, xrange)):
692 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
692 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
693 else:
693 else:
694 return self.direct_view(key)
694 return self.direct_view(key)
695
695
696 #--------------------------------------------------------------------------
696 #--------------------------------------------------------------------------
697 # Begin public methods
697 # Begin public methods
698 #--------------------------------------------------------------------------
698 #--------------------------------------------------------------------------
699
699
700 @property
700 @property
701 def ids(self):
701 def ids(self):
702 """Always up-to-date ids property."""
702 """Always up-to-date ids property."""
703 self._flush_notifications()
703 self._flush_notifications()
704 # always copy:
704 # always copy:
705 return list(self._ids)
705 return list(self._ids)
706
706
707 def close(self):
707 def close(self):
708 if self._closed:
708 if self._closed:
709 return
709 return
710 snames = filter(lambda n: n.endswith('socket'), dir(self))
710 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 for socket in map(lambda name: getattr(self, name), snames):
711 for socket in map(lambda name: getattr(self, name), snames):
712 if isinstance(socket, zmq.Socket) and not socket.closed:
712 if isinstance(socket, zmq.Socket) and not socket.closed:
713 socket.close()
713 socket.close()
714 self._closed = True
714 self._closed = True
715
715
716 def spin(self):
716 def spin(self):
717 """Flush any registration notifications and execution results
717 """Flush any registration notifications and execution results
718 waiting in the ZMQ queue.
718 waiting in the ZMQ queue.
719 """
719 """
720 if self._notification_socket:
720 if self._notification_socket:
721 self._flush_notifications()
721 self._flush_notifications()
722 if self._mux_socket:
722 if self._mux_socket:
723 self._flush_results(self._mux_socket)
723 self._flush_results(self._mux_socket)
724 if self._task_socket:
724 if self._task_socket:
725 self._flush_results(self._task_socket)
725 self._flush_results(self._task_socket)
726 if self._control_socket:
726 if self._control_socket:
727 self._flush_control(self._control_socket)
727 self._flush_control(self._control_socket)
728 if self._iopub_socket:
728 if self._iopub_socket:
729 self._flush_iopub(self._iopub_socket)
729 self._flush_iopub(self._iopub_socket)
730 if self._query_socket:
730 if self._query_socket:
731 self._flush_ignored_hub_replies()
731 self._flush_ignored_hub_replies()
732
732
733 def wait(self, jobs=None, timeout=-1):
733 def wait(self, jobs=None, timeout=-1):
734 """waits on one or more `jobs`, for up to `timeout` seconds.
734 """waits on one or more `jobs`, for up to `timeout` seconds.
735
735
736 Parameters
736 Parameters
737 ----------
737 ----------
738
738
739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 ints are indices to self.history
740 ints are indices to self.history
741 strs are msg_ids
741 strs are msg_ids
742 default: wait on all outstanding messages
742 default: wait on all outstanding messages
743 timeout : float
743 timeout : float
744 a time in seconds, after which to give up.
744 a time in seconds, after which to give up.
745 default is -1, which means no timeout
745 default is -1, which means no timeout
746
746
747 Returns
747 Returns
748 -------
748 -------
749
749
750 True : when all msg_ids are done
750 True : when all msg_ids are done
751 False : timeout reached, some msg_ids still outstanding
751 False : timeout reached, some msg_ids still outstanding
752 """
752 """
753 tic = time.time()
753 tic = time.time()
754 if jobs is None:
754 if jobs is None:
755 theids = self.outstanding
755 theids = self.outstanding
756 else:
756 else:
757 if isinstance(jobs, (int, str, AsyncResult)):
757 if isinstance(jobs, (int, str, AsyncResult)):
758 jobs = [jobs]
758 jobs = [jobs]
759 theids = set()
759 theids = set()
760 for job in jobs:
760 for job in jobs:
761 if isinstance(job, int):
761 if isinstance(job, int):
762 # index access
762 # index access
763 job = self.history[job]
763 job = self.history[job]
764 elif isinstance(job, AsyncResult):
764 elif isinstance(job, AsyncResult):
765 map(theids.add, job.msg_ids)
765 map(theids.add, job.msg_ids)
766 continue
766 continue
767 theids.add(job)
767 theids.add(job)
768 if not theids.intersection(self.outstanding):
768 if not theids.intersection(self.outstanding):
769 return True
769 return True
770 self.spin()
770 self.spin()
771 while theids.intersection(self.outstanding):
771 while theids.intersection(self.outstanding):
772 if timeout >= 0 and ( time.time()-tic ) > timeout:
772 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 break
773 break
774 time.sleep(1e-3)
774 time.sleep(1e-3)
775 self.spin()
775 self.spin()
776 return len(theids.intersection(self.outstanding)) == 0
776 return len(theids.intersection(self.outstanding)) == 0
777
777
778 #--------------------------------------------------------------------------
778 #--------------------------------------------------------------------------
779 # Control methods
779 # Control methods
780 #--------------------------------------------------------------------------
780 #--------------------------------------------------------------------------
781
781
782 @spin_first
782 @spin_first
783 def clear(self, targets=None, block=None):
783 def clear(self, targets=None, block=None):
784 """Clear the namespace in target(s)."""
784 """Clear the namespace in target(s)."""
785 block = self.block if block is None else block
785 block = self.block if block is None else block
786 targets = self._build_targets(targets)[0]
786 targets = self._build_targets(targets)[0]
787 for t in targets:
787 for t in targets:
788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 error = False
789 error = False
790 if block:
790 if block:
791 self._flush_ignored_control()
791 self._flush_ignored_control()
792 for i in range(len(targets)):
792 for i in range(len(targets)):
793 idents,msg = self.session.recv(self._control_socket,0)
793 idents,msg = self.session.recv(self._control_socket,0)
794 if self.debug:
794 if self.debug:
795 pprint(msg)
795 pprint(msg)
796 if msg['content']['status'] != 'ok':
796 if msg['content']['status'] != 'ok':
797 error = self._unwrap_exception(msg['content'])
797 error = self._unwrap_exception(msg['content'])
798 else:
798 else:
799 self._ignored_control_replies += len(targets)
799 self._ignored_control_replies += len(targets)
800 if error:
800 if error:
801 raise error
801 raise error
802
802
803
803
804 @spin_first
804 @spin_first
805 def abort(self, jobs=None, targets=None, block=None):
805 def abort(self, jobs=None, targets=None, block=None):
806 """Abort specific jobs from the execution queues of target(s).
806 """Abort specific jobs from the execution queues of target(s).
807
807
808 This is a mechanism to prevent jobs that have already been submitted
808 This is a mechanism to prevent jobs that have already been submitted
809 from executing.
809 from executing.
810
810
811 Parameters
811 Parameters
812 ----------
812 ----------
813
813
814 jobs : msg_id, list of msg_ids, or AsyncResult
814 jobs : msg_id, list of msg_ids, or AsyncResult
815 The jobs to be aborted
815 The jobs to be aborted
816
816
817
817
818 """
818 """
819 block = self.block if block is None else block
819 block = self.block if block is None else block
820 targets = self._build_targets(targets)[0]
820 targets = self._build_targets(targets)[0]
821 msg_ids = []
821 msg_ids = []
822 if isinstance(jobs, (basestring,AsyncResult)):
822 if isinstance(jobs, (basestring,AsyncResult)):
823 jobs = [jobs]
823 jobs = [jobs]
824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 if bad_ids:
825 if bad_ids:
826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 for j in jobs:
827 for j in jobs:
828 if isinstance(j, AsyncResult):
828 if isinstance(j, AsyncResult):
829 msg_ids.extend(j.msg_ids)
829 msg_ids.extend(j.msg_ids)
830 else:
830 else:
831 msg_ids.append(j)
831 msg_ids.append(j)
832 content = dict(msg_ids=msg_ids)
832 content = dict(msg_ids=msg_ids)
833 for t in targets:
833 for t in targets:
834 self.session.send(self._control_socket, 'abort_request',
834 self.session.send(self._control_socket, 'abort_request',
835 content=content, ident=t)
835 content=content, ident=t)
836 error = False
836 error = False
837 if block:
837 if block:
838 self._flush_ignored_control()
838 self._flush_ignored_control()
839 for i in range(len(targets)):
839 for i in range(len(targets)):
840 idents,msg = self.session.recv(self._control_socket,0)
840 idents,msg = self.session.recv(self._control_socket,0)
841 if self.debug:
841 if self.debug:
842 pprint(msg)
842 pprint(msg)
843 if msg['content']['status'] != 'ok':
843 if msg['content']['status'] != 'ok':
844 error = self._unwrap_exception(msg['content'])
844 error = self._unwrap_exception(msg['content'])
845 else:
845 else:
846 self._ignored_control_replies += len(targets)
846 self._ignored_control_replies += len(targets)
847 if error:
847 if error:
848 raise error
848 raise error
849
849
850 @spin_first
850 @spin_first
851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 """Terminates one or more engine processes, optionally including the hub."""
852 """Terminates one or more engine processes, optionally including the hub."""
853 block = self.block if block is None else block
853 block = self.block if block is None else block
854 if hub:
854 if hub:
855 targets = 'all'
855 targets = 'all'
856 targets = self._build_targets(targets)[0]
856 targets = self._build_targets(targets)[0]
857 for t in targets:
857 for t in targets:
858 self.session.send(self._control_socket, 'shutdown_request',
858 self.session.send(self._control_socket, 'shutdown_request',
859 content={'restart':restart},ident=t)
859 content={'restart':restart},ident=t)
860 error = False
860 error = False
861 if block or hub:
861 if block or hub:
862 self._flush_ignored_control()
862 self._flush_ignored_control()
863 for i in range(len(targets)):
863 for i in range(len(targets)):
864 idents,msg = self.session.recv(self._control_socket, 0)
864 idents,msg = self.session.recv(self._control_socket, 0)
865 if self.debug:
865 if self.debug:
866 pprint(msg)
866 pprint(msg)
867 if msg['content']['status'] != 'ok':
867 if msg['content']['status'] != 'ok':
868 error = self._unwrap_exception(msg['content'])
868 error = self._unwrap_exception(msg['content'])
869 else:
869 else:
870 self._ignored_control_replies += len(targets)
870 self._ignored_control_replies += len(targets)
871
871
872 if hub:
872 if hub:
873 time.sleep(0.25)
873 time.sleep(0.25)
874 self.session.send(self._query_socket, 'shutdown_request')
874 self.session.send(self._query_socket, 'shutdown_request')
875 idents,msg = self.session.recv(self._query_socket, 0)
875 idents,msg = self.session.recv(self._query_socket, 0)
876 if self.debug:
876 if self.debug:
877 pprint(msg)
877 pprint(msg)
878 if msg['content']['status'] != 'ok':
878 if msg['content']['status'] != 'ok':
879 error = self._unwrap_exception(msg['content'])
879 error = self._unwrap_exception(msg['content'])
880
880
881 if error:
881 if error:
882 raise error
882 raise error
883
883
884 #--------------------------------------------------------------------------
884 #--------------------------------------------------------------------------
885 # Execution related methods
885 # Execution related methods
886 #--------------------------------------------------------------------------
886 #--------------------------------------------------------------------------
887
887
888 def _maybe_raise(self, result):
888 def _maybe_raise(self, result):
889 """wrapper for maybe raising an exception if apply failed."""
889 """wrapper for maybe raising an exception if apply failed."""
890 if isinstance(result, error.RemoteError):
890 if isinstance(result, error.RemoteError):
891 raise result
891 raise result
892
892
893 return result
893 return result
894
894
895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 ident=None):
896 ident=None):
897 """construct and send an apply message via a socket.
897 """construct and send an apply message via a socket.
898
898
899 This is the principal method with which all engine execution is performed by views.
899 This is the principal method with which all engine execution is performed by views.
900 """
900 """
901
901
902 assert not self._closed, "cannot use me anymore, I'm closed!"
902 assert not self._closed, "cannot use me anymore, I'm closed!"
903 # defaults:
903 # defaults:
904 args = args if args is not None else []
904 args = args if args is not None else []
905 kwargs = kwargs if kwargs is not None else {}
905 kwargs = kwargs if kwargs is not None else {}
906 subheader = subheader if subheader is not None else {}
906 subheader = subheader if subheader is not None else {}
907
907
908 # validate arguments
908 # validate arguments
909 if not callable(f):
909 if not callable(f):
910 raise TypeError("f must be callable, not %s"%type(f))
910 raise TypeError("f must be callable, not %s"%type(f))
911 if not isinstance(args, (tuple, list)):
911 if not isinstance(args, (tuple, list)):
912 raise TypeError("args must be tuple or list, not %s"%type(args))
912 raise TypeError("args must be tuple or list, not %s"%type(args))
913 if not isinstance(kwargs, dict):
913 if not isinstance(kwargs, dict):
914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 if not isinstance(subheader, dict):
915 if not isinstance(subheader, dict):
916 raise TypeError("subheader must be dict, not %s"%type(subheader))
916 raise TypeError("subheader must be dict, not %s"%type(subheader))
917
917
918 bufs = util.pack_apply_message(f,args,kwargs)
918 bufs = util.pack_apply_message(f,args,kwargs)
919
919
920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 subheader=subheader, track=track)
921 subheader=subheader, track=track)
922
922
923 msg_id = msg['msg_id']
923 msg_id = msg['msg_id']
924 self.outstanding.add(msg_id)
924 self.outstanding.add(msg_id)
925 if ident:
925 if ident:
926 # possibly routed to a specific engine
926 # possibly routed to a specific engine
927 if isinstance(ident, list):
927 if isinstance(ident, list):
928 ident = ident[-1]
928 ident = ident[-1]
929 if ident in self._engines.values():
929 if ident in self._engines.values():
930 # save for later, in case of engine death
930 # save for later, in case of engine death
931 self._outstanding_dict[ident].add(msg_id)
931 self._outstanding_dict[ident].add(msg_id)
932 self.history.append(msg_id)
932 self.history.append(msg_id)
933 self.metadata[msg_id]['submitted'] = datetime.now()
933 self.metadata[msg_id]['submitted'] = datetime.now()
934
934
935 return msg
935 return msg
936
936
937 #--------------------------------------------------------------------------
937 #--------------------------------------------------------------------------
938 # construct a View object
938 # construct a View object
939 #--------------------------------------------------------------------------
939 #--------------------------------------------------------------------------
940
940
941 def load_balanced_view(self, targets=None):
941 def load_balanced_view(self, targets=None):
942 """construct a DirectView object.
942 """construct a DirectView object.
943
943
944 If no arguments are specified, create a LoadBalancedView
944 If no arguments are specified, create a LoadBalancedView
945 using all engines.
945 using all engines.
946
946
947 Parameters
947 Parameters
948 ----------
948 ----------
949
949
950 targets: list,slice,int,etc. [default: use all engines]
950 targets: list,slice,int,etc. [default: use all engines]
951 The subset of engines across which to load-balance
951 The subset of engines across which to load-balance
952 """
952 """
953 if targets is not None:
953 if targets is not None:
954 targets = self._build_targets(targets)[1]
954 targets = self._build_targets(targets)[1]
955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956
956
957 def direct_view(self, targets='all'):
957 def direct_view(self, targets='all'):
958 """construct a DirectView object.
958 """construct a DirectView object.
959
959
960 If no targets are specified, create a DirectView
960 If no targets are specified, create a DirectView
961 using all engines.
961 using all engines.
962
962
963 Parameters
963 Parameters
964 ----------
964 ----------
965
965
966 targets: list,slice,int,etc. [default: use all engines]
966 targets: list,slice,int,etc. [default: use all engines]
967 The engines to use for the View
967 The engines to use for the View
968 """
968 """
969 single = isinstance(targets, int)
969 single = isinstance(targets, int)
970 targets = self._build_targets(targets)[1]
970 targets = self._build_targets(targets)[1]
971 if single:
971 if single:
972 targets = targets[0]
972 targets = targets[0]
973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974
974
975 #--------------------------------------------------------------------------
975 #--------------------------------------------------------------------------
976 # Query methods
976 # Query methods
977 #--------------------------------------------------------------------------
977 #--------------------------------------------------------------------------
978
978
979 @spin_first
979 @spin_first
980 def get_result(self, indices_or_msg_ids=None, block=None):
980 def get_result(self, indices_or_msg_ids=None, block=None):
981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982
982
983 If the client already has the results, no request to the Hub will be made.
983 If the client already has the results, no request to the Hub will be made.
984
984
985 This is a convenient way to construct AsyncResult objects, which are wrappers
985 This is a convenient way to construct AsyncResult objects, which are wrappers
986 that include metadata about execution, and allow for awaiting results that
986 that include metadata about execution, and allow for awaiting results that
987 were not submitted by this Client.
987 were not submitted by this Client.
988
988
989 It can also be a convenient way to retrieve the metadata associated with
989 It can also be a convenient way to retrieve the metadata associated with
990 blocking execution, since it always retrieves
990 blocking execution, since it always retrieves
991
991
992 Examples
992 Examples
993 --------
993 --------
994 ::
994 ::
995
995
996 In [10]: r = client.apply()
996 In [10]: r = client.apply()
997
997
998 Parameters
998 Parameters
999 ----------
999 ----------
1000
1000
1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 The indices or msg_ids of indices to be retrieved
1002 The indices or msg_ids of indices to be retrieved
1003
1003
1004 block : bool
1004 block : bool
1005 Whether to wait for the result to be done
1005 Whether to wait for the result to be done
1006
1006
1007 Returns
1007 Returns
1008 -------
1008 -------
1009
1009
1010 AsyncResult
1010 AsyncResult
1011 A single AsyncResult object will always be returned.
1011 A single AsyncResult object will always be returned.
1012
1012
1013 AsyncHubResult
1013 AsyncHubResult
1014 A subclass of AsyncResult that retrieves results from the Hub
1014 A subclass of AsyncResult that retrieves results from the Hub
1015
1015
1016 """
1016 """
1017 block = self.block if block is None else block
1017 block = self.block if block is None else block
1018 if indices_or_msg_ids is None:
1018 if indices_or_msg_ids is None:
1019 indices_or_msg_ids = -1
1019 indices_or_msg_ids = -1
1020
1020
1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 indices_or_msg_ids = [indices_or_msg_ids]
1022 indices_or_msg_ids = [indices_or_msg_ids]
1023
1023
1024 theids = []
1024 theids = []
1025 for id in indices_or_msg_ids:
1025 for id in indices_or_msg_ids:
1026 if isinstance(id, int):
1026 if isinstance(id, int):
1027 id = self.history[id]
1027 id = self.history[id]
1028 if not isinstance(id, str):
1028 if not isinstance(id, str):
1029 raise TypeError("indices must be str or int, not %r"%id)
1029 raise TypeError("indices must be str or int, not %r"%id)
1030 theids.append(id)
1030 theids.append(id)
1031
1031
1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034
1034
1035 if remote_ids:
1035 if remote_ids:
1036 ar = AsyncHubResult(self, msg_ids=theids)
1036 ar = AsyncHubResult(self, msg_ids=theids)
1037 else:
1037 else:
1038 ar = AsyncResult(self, msg_ids=theids)
1038 ar = AsyncResult(self, msg_ids=theids)
1039
1039
1040 if block:
1040 if block:
1041 ar.wait()
1041 ar.wait()
1042
1042
1043 return ar
1043 return ar
1044
1045 @spin_first
1046 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 """Resubmit one or more tasks.
1048
1049 in-flight tasks may not be resubmitted.
1050
1051 Parameters
1052 ----------
1053
1054 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 The indices or msg_ids of indices to be retrieved
1056
1057 block : bool
1058 Whether to wait for the result to be done
1059
1060 Returns
1061 -------
1062
1063 AsyncHubResult
1064 A subclass of AsyncResult that retrieves results from the Hub
1065
1066 """
1067 block = self.block if block is None else block
1068 if indices_or_msg_ids is None:
1069 indices_or_msg_ids = -1
1070
1071 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 indices_or_msg_ids = [indices_or_msg_ids]
1073
1074 theids = []
1075 for id in indices_or_msg_ids:
1076 if isinstance(id, int):
1077 id = self.history[id]
1078 if not isinstance(id, str):
1079 raise TypeError("indices must be str or int, not %r"%id)
1080 theids.append(id)
1081
1082 for msg_id in theids:
1083 self.outstanding.discard(msg_id)
1084 if msg_id in self.history:
1085 self.history.remove(msg_id)
1086 self.results.pop(msg_id, None)
1087 self.metadata.pop(msg_id, None)
1088 content = dict(msg_ids = theids)
1089
1090 self.session.send(self._query_socket, 'resubmit_request', content)
1091
1092 zmq.select([self._query_socket], [], [])
1093 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 if self.debug:
1095 pprint(msg)
1096 content = msg['content']
1097 if content['status'] != 'ok':
1098 raise self._unwrap_exception(content)
1099
1100 ar = AsyncHubResult(self, msg_ids=theids)
1101
1102 if block:
1103 ar.wait()
1104
1105 return ar
1044
1106
1045 @spin_first
1107 @spin_first
1046 def result_status(self, msg_ids, status_only=True):
1108 def result_status(self, msg_ids, status_only=True):
1047 """Check on the status of the result(s) of the apply request with `msg_ids`.
1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1048
1110
1049 If status_only is False, then the actual results will be retrieved, else
1111 If status_only is False, then the actual results will be retrieved, else
1050 only the status of the results will be checked.
1112 only the status of the results will be checked.
1051
1113
1052 Parameters
1114 Parameters
1053 ----------
1115 ----------
1054
1116
1055 msg_ids : list of msg_ids
1117 msg_ids : list of msg_ids
1056 if int:
1118 if int:
1057 Passed as index to self.history for convenience.
1119 Passed as index to self.history for convenience.
1058 status_only : bool (default: True)
1120 status_only : bool (default: True)
1059 if False:
1121 if False:
1060 Retrieve the actual results of completed tasks.
1122 Retrieve the actual results of completed tasks.
1061
1123
1062 Returns
1124 Returns
1063 -------
1125 -------
1064
1126
1065 results : dict
1127 results : dict
1066 There will always be the keys 'pending' and 'completed', which will
1128 There will always be the keys 'pending' and 'completed', which will
1067 be lists of msg_ids that are incomplete or complete. If `status_only`
1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1068 is False, then completed results will be keyed by their `msg_id`.
1130 is False, then completed results will be keyed by their `msg_id`.
1069 """
1131 """
1070 if not isinstance(msg_ids, (list,tuple)):
1132 if not isinstance(msg_ids, (list,tuple)):
1071 msg_ids = [msg_ids]
1133 msg_ids = [msg_ids]
1072
1134
1073 theids = []
1135 theids = []
1074 for msg_id in msg_ids:
1136 for msg_id in msg_ids:
1075 if isinstance(msg_id, int):
1137 if isinstance(msg_id, int):
1076 msg_id = self.history[msg_id]
1138 msg_id = self.history[msg_id]
1077 if not isinstance(msg_id, basestring):
1139 if not isinstance(msg_id, basestring):
1078 raise TypeError("msg_ids must be str, not %r"%msg_id)
1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1079 theids.append(msg_id)
1141 theids.append(msg_id)
1080
1142
1081 completed = []
1143 completed = []
1082 local_results = {}
1144 local_results = {}
1083
1145
1084 # comment this block out to temporarily disable local shortcut:
1146 # comment this block out to temporarily disable local shortcut:
1085 for msg_id in theids:
1147 for msg_id in theids:
1086 if msg_id in self.results:
1148 if msg_id in self.results:
1087 completed.append(msg_id)
1149 completed.append(msg_id)
1088 local_results[msg_id] = self.results[msg_id]
1150 local_results[msg_id] = self.results[msg_id]
1089 theids.remove(msg_id)
1151 theids.remove(msg_id)
1090
1152
1091 if theids: # some not locally cached
1153 if theids: # some not locally cached
1092 content = dict(msg_ids=theids, status_only=status_only)
1154 content = dict(msg_ids=theids, status_only=status_only)
1093 msg = self.session.send(self._query_socket, "result_request", content=content)
1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1094 zmq.select([self._query_socket], [], [])
1156 zmq.select([self._query_socket], [], [])
1095 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1096 if self.debug:
1158 if self.debug:
1097 pprint(msg)
1159 pprint(msg)
1098 content = msg['content']
1160 content = msg['content']
1099 if content['status'] != 'ok':
1161 if content['status'] != 'ok':
1100 raise self._unwrap_exception(content)
1162 raise self._unwrap_exception(content)
1101 buffers = msg['buffers']
1163 buffers = msg['buffers']
1102 else:
1164 else:
1103 content = dict(completed=[],pending=[])
1165 content = dict(completed=[],pending=[])
1104
1166
1105 content['completed'].extend(completed)
1167 content['completed'].extend(completed)
1106
1168
1107 if status_only:
1169 if status_only:
1108 return content
1170 return content
1109
1171
1110 failures = []
1172 failures = []
1111 # load cached results into result:
1173 # load cached results into result:
1112 content.update(local_results)
1174 content.update(local_results)
1113 # update cache with results:
1175 # update cache with results:
1114 for msg_id in sorted(theids):
1176 for msg_id in sorted(theids):
1115 if msg_id in content['completed']:
1177 if msg_id in content['completed']:
1116 rec = content[msg_id]
1178 rec = content[msg_id]
1117 parent = rec['header']
1179 parent = rec['header']
1118 header = rec['result_header']
1180 header = rec['result_header']
1119 rcontent = rec['result_content']
1181 rcontent = rec['result_content']
1120 iodict = rec['io']
1182 iodict = rec['io']
1121 if isinstance(rcontent, str):
1183 if isinstance(rcontent, str):
1122 rcontent = self.session.unpack(rcontent)
1184 rcontent = self.session.unpack(rcontent)
1123
1185
1124 md = self.metadata[msg_id]
1186 md = self.metadata[msg_id]
1125 md.update(self._extract_metadata(header, parent, rcontent))
1187 md.update(self._extract_metadata(header, parent, rcontent))
1126 md.update(iodict)
1188 md.update(iodict)
1127
1189
1128 if rcontent['status'] == 'ok':
1190 if rcontent['status'] == 'ok':
1129 res,buffers = util.unserialize_object(buffers)
1191 res,buffers = util.unserialize_object(buffers)
1130 else:
1192 else:
1131 print rcontent
1193 print rcontent
1132 res = self._unwrap_exception(rcontent)
1194 res = self._unwrap_exception(rcontent)
1133 failures.append(res)
1195 failures.append(res)
1134
1196
1135 self.results[msg_id] = res
1197 self.results[msg_id] = res
1136 content[msg_id] = res
1198 content[msg_id] = res
1137
1199
1138 if len(theids) == 1 and failures:
1200 if len(theids) == 1 and failures:
1139 raise failures[0]
1201 raise failures[0]
1140
1202
1141 error.collect_exceptions(failures, "result_status")
1203 error.collect_exceptions(failures, "result_status")
1142 return content
1204 return content
1143
1205
1144 @spin_first
1206 @spin_first
1145 def queue_status(self, targets='all', verbose=False):
1207 def queue_status(self, targets='all', verbose=False):
1146 """Fetch the status of engine queues.
1208 """Fetch the status of engine queues.
1147
1209
1148 Parameters
1210 Parameters
1149 ----------
1211 ----------
1150
1212
1151 targets : int/str/list of ints/strs
1213 targets : int/str/list of ints/strs
1152 the engines whose states are to be queried.
1214 the engines whose states are to be queried.
1153 default : all
1215 default : all
1154 verbose : bool
1216 verbose : bool
1155 Whether to return lengths only, or lists of ids for each element
1217 Whether to return lengths only, or lists of ids for each element
1156 """
1218 """
1157 engine_ids = self._build_targets(targets)[1]
1219 engine_ids = self._build_targets(targets)[1]
1158 content = dict(targets=engine_ids, verbose=verbose)
1220 content = dict(targets=engine_ids, verbose=verbose)
1159 self.session.send(self._query_socket, "queue_request", content=content)
1221 self.session.send(self._query_socket, "queue_request", content=content)
1160 idents,msg = self.session.recv(self._query_socket, 0)
1222 idents,msg = self.session.recv(self._query_socket, 0)
1161 if self.debug:
1223 if self.debug:
1162 pprint(msg)
1224 pprint(msg)
1163 content = msg['content']
1225 content = msg['content']
1164 status = content.pop('status')
1226 status = content.pop('status')
1165 if status != 'ok':
1227 if status != 'ok':
1166 raise self._unwrap_exception(content)
1228 raise self._unwrap_exception(content)
1167 content = util.rekey(content)
1229 content = util.rekey(content)
1168 if isinstance(targets, int):
1230 if isinstance(targets, int):
1169 return content[targets]
1231 return content[targets]
1170 else:
1232 else:
1171 return content
1233 return content
1172
1234
1173 @spin_first
1235 @spin_first
1174 def purge_results(self, jobs=[], targets=[]):
1236 def purge_results(self, jobs=[], targets=[]):
1175 """Tell the Hub to forget results.
1237 """Tell the Hub to forget results.
1176
1238
1177 Individual results can be purged by msg_id, or the entire
1239 Individual results can be purged by msg_id, or the entire
1178 history of specific targets can be purged.
1240 history of specific targets can be purged.
1179
1241
1180 Parameters
1242 Parameters
1181 ----------
1243 ----------
1182
1244
1183 jobs : str or list of str or AsyncResult objects
1245 jobs : str or list of str or AsyncResult objects
1184 the msg_ids whose results should be forgotten.
1246 the msg_ids whose results should be forgotten.
1185 targets : int/str/list of ints/strs
1247 targets : int/str/list of ints/strs
1186 The targets, by uuid or int_id, whose entire history is to be purged.
1248 The targets, by uuid or int_id, whose entire history is to be purged.
1187 Use `targets='all'` to scrub everything from the Hub's memory.
1249 Use `targets='all'` to scrub everything from the Hub's memory.
1188
1250
1189 default : None
1251 default : None
1190 """
1252 """
1191 if not targets and not jobs:
1253 if not targets and not jobs:
1192 raise ValueError("Must specify at least one of `targets` and `jobs`")
1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1193 if targets:
1255 if targets:
1194 targets = self._build_targets(targets)[1]
1256 targets = self._build_targets(targets)[1]
1195
1257
1196 # construct msg_ids from jobs
1258 # construct msg_ids from jobs
1197 msg_ids = []
1259 msg_ids = []
1198 if isinstance(jobs, (basestring,AsyncResult)):
1260 if isinstance(jobs, (basestring,AsyncResult)):
1199 jobs = [jobs]
1261 jobs = [jobs]
1200 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1201 if bad_ids:
1263 if bad_ids:
1202 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1203 for j in jobs:
1265 for j in jobs:
1204 if isinstance(j, AsyncResult):
1266 if isinstance(j, AsyncResult):
1205 msg_ids.extend(j.msg_ids)
1267 msg_ids.extend(j.msg_ids)
1206 else:
1268 else:
1207 msg_ids.append(j)
1269 msg_ids.append(j)
1208
1270
1209 content = dict(targets=targets, msg_ids=msg_ids)
1271 content = dict(targets=targets, msg_ids=msg_ids)
1210 self.session.send(self._query_socket, "purge_request", content=content)
1272 self.session.send(self._query_socket, "purge_request", content=content)
1211 idents, msg = self.session.recv(self._query_socket, 0)
1273 idents, msg = self.session.recv(self._query_socket, 0)
1212 if self.debug:
1274 if self.debug:
1213 pprint(msg)
1275 pprint(msg)
1214 content = msg['content']
1276 content = msg['content']
1215 if content['status'] != 'ok':
1277 if content['status'] != 'ok':
1216 raise self._unwrap_exception(content)
1278 raise self._unwrap_exception(content)
1217
1279
1218 @spin_first
1280 @spin_first
1219 def hub_history(self):
1281 def hub_history(self):
1220 """Get the Hub's history
1282 """Get the Hub's history
1221
1283
1222 Just like the Client, the Hub has a history, which is a list of msg_ids.
1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1223 This will contain the history of all clients, and, depending on configuration,
1285 This will contain the history of all clients, and, depending on configuration,
1224 may contain history across multiple cluster sessions.
1286 may contain history across multiple cluster sessions.
1225
1287
1226 Any msg_id returned here is a valid argument to `get_result`.
1288 Any msg_id returned here is a valid argument to `get_result`.
1227
1289
1228 Returns
1290 Returns
1229 -------
1291 -------
1230
1292
1231 msg_ids : list of strs
1293 msg_ids : list of strs
1232 list of all msg_ids, ordered by task submission time.
1294 list of all msg_ids, ordered by task submission time.
1233 """
1295 """
1234
1296
1235 self.session.send(self._query_socket, "history_request", content={})
1297 self.session.send(self._query_socket, "history_request", content={})
1236 idents, msg = self.session.recv(self._query_socket, 0)
1298 idents, msg = self.session.recv(self._query_socket, 0)
1237
1299
1238 if self.debug:
1300 if self.debug:
1239 pprint(msg)
1301 pprint(msg)
1240 content = msg['content']
1302 content = msg['content']
1241 if content['status'] != 'ok':
1303 if content['status'] != 'ok':
1242 raise self._unwrap_exception(content)
1304 raise self._unwrap_exception(content)
1243 else:
1305 else:
1244 return content['history']
1306 return content['history']
1245
1307
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything.
            'msg_id' will *always* be included.
        """
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        # blocking recv: wait for the Hub's reply on the query socket
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']
        # per-record buffer counts; None means no buffer info was requested
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers: the flat buffer list is carved up in record
            # order -- request buffers first, then result buffers. Order of
            # the two slices below must match how the Hub packed them.
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
            # turn timestamps back into times
            for key in 'submitted started completed resubmitted'.split():
                maybedate = rec.get(key, None)
                if maybedate and util.ISO8601_RE.match(maybedate):
                    rec[key] = datetime.strptime(maybedate, util.ISO8601)

        return records
1291
1353
1292 __all__ = [ 'Client' ]
1354 __all__ = [ 'Client' ]
@@ -1,1193 +1,1284 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5 """
5 """
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010 The IPython Development Team
7 # Copyright (C) 2010 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 import sys
18 import sys
19 import time
19 import time
20 from datetime import datetime
20 from datetime import datetime
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop
23 from zmq.eventloop import ioloop
24 from zmq.eventloop.zmqstream import ZMQStream
24 from zmq.eventloop.zmqstream import ZMQStream
25
25
26 # internal:
26 # internal:
27 from IPython.utils.importstring import import_item
27 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29
29
30 from IPython.parallel import error, util
30 from IPython.parallel import error, util
31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32
32
33 from .heartmonitor import HeartMonitor
33 from .heartmonitor import HeartMonitor
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Code
36 # Code
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38
38
39 def _passer(*args, **kwargs):
39 def _passer(*args, **kwargs):
40 return
40 return
41
41
42 def _printer(*args, **kwargs):
42 def _printer(*args, **kwargs):
43 print (args)
43 print (args)
44 print (kwargs)
44 print (kwargs)
45
45
def empty_record():
    """Return an empty dict with all record keys."""
    # every field except the stdio captures defaults to None
    rec = dict.fromkeys([
        'msg_id', 'header', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed',
        'resubmitted', 'result_header', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    ])
    # captured stream output accumulates as strings
    rec['stdout'] = ''
    rec['stderr'] = ''
    return rec
69
69
def init_record(msg):
    """Initialize a TaskRecord based on a request."""
    header = msg['header']
    # start from a blank record and fill in the request-derived fields
    rec = empty_record()
    rec.update({
        'msg_id': header['msg_id'],
        'header': header,
        'content': msg['content'],
        'buffers': msg['buffers'],
        # headers carry ISO8601 date strings; store a real datetime
        'submitted': datetime.strptime(header['date'], util.ISO8601),
    })
    return rec
94
94
95
95
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    # integer engine id assigned by the Hub
    id=Int(0)
    # zmq identity of the engine's queue (XREQ) socket
    queue=Str()
    # zmq identity of the engine's control socket
    control=Str()
    # zmq identity used on the registration socket
    registration=Str()
    # zmq identity of the heartbeat socket
    heartbeat=Str()
    # presumably the msg_ids of tasks pending on this engine -- confirm
    # against the Hub code that populates it
    pending=Set()
111
111
112 class HubFactory(RegistrationFactory):
112 class HubFactory(RegistrationFactory):
113 """The Configurable for setting up a Hub."""
113 """The Configurable for setting up a Hub."""
114
114
115 # name of a scheduler scheme
115 # name of a scheduler scheme
116 scheme = Str('leastload', config=True)
116 scheme = Str('leastload', config=True)
117
117
118 # port-pairs for monitoredqueues:
118 # port-pairs for monitoredqueues:
119 hb = Instance(list, config=True)
119 hb = Instance(list, config=True)
120 def _hb_default(self):
120 def _hb_default(self):
121 return util.select_random_ports(2)
121 return util.select_random_ports(2)
122
122
123 mux = Instance(list, config=True)
123 mux = Instance(list, config=True)
124 def _mux_default(self):
124 def _mux_default(self):
125 return util.select_random_ports(2)
125 return util.select_random_ports(2)
126
126
127 task = Instance(list, config=True)
127 task = Instance(list, config=True)
128 def _task_default(self):
128 def _task_default(self):
129 return util.select_random_ports(2)
129 return util.select_random_ports(2)
130
130
131 control = Instance(list, config=True)
131 control = Instance(list, config=True)
132 def _control_default(self):
132 def _control_default(self):
133 return util.select_random_ports(2)
133 return util.select_random_ports(2)
134
134
135 iopub = Instance(list, config=True)
135 iopub = Instance(list, config=True)
136 def _iopub_default(self):
136 def _iopub_default(self):
137 return util.select_random_ports(2)
137 return util.select_random_ports(2)
138
138
139 # single ports:
139 # single ports:
140 mon_port = Instance(int, config=True)
140 mon_port = Instance(int, config=True)
141 def _mon_port_default(self):
141 def _mon_port_default(self):
142 return util.select_random_ports(1)[0]
142 return util.select_random_ports(1)[0]
143
143
144 notifier_port = Instance(int, config=True)
144 notifier_port = Instance(int, config=True)
145 def _notifier_port_default(self):
145 def _notifier_port_default(self):
146 return util.select_random_ports(1)[0]
146 return util.select_random_ports(1)[0]
147
147
148 ping = Int(1000, config=True) # ping frequency
148 ping = Int(1000, config=True) # ping frequency
149
149
150 engine_ip = CStr('127.0.0.1', config=True)
150 engine_ip = CStr('127.0.0.1', config=True)
151 engine_transport = CStr('tcp', config=True)
151 engine_transport = CStr('tcp', config=True)
152
152
153 client_ip = CStr('127.0.0.1', config=True)
153 client_ip = CStr('127.0.0.1', config=True)
154 client_transport = CStr('tcp', config=True)
154 client_transport = CStr('tcp', config=True)
155
155
156 monitor_ip = CStr('127.0.0.1', config=True)
156 monitor_ip = CStr('127.0.0.1', config=True)
157 monitor_transport = CStr('tcp', config=True)
157 monitor_transport = CStr('tcp', config=True)
158
158
159 monitor_url = CStr('')
159 monitor_url = CStr('')
160
160
161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
162
162
163 # not configurable
163 # not configurable
164 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
164 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
165 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
165 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
166 subconstructors = List()
166 subconstructors = List()
167 _constructed = Bool(False)
167 _constructed = Bool(False)
168
168
169 def _ip_changed(self, name, old, new):
169 def _ip_changed(self, name, old, new):
170 self.engine_ip = new
170 self.engine_ip = new
171 self.client_ip = new
171 self.client_ip = new
172 self.monitor_ip = new
172 self.monitor_ip = new
173 self._update_monitor_url()
173 self._update_monitor_url()
174
174
175 def _update_monitor_url(self):
175 def _update_monitor_url(self):
176 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
176 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
177
177
178 def _transport_changed(self, name, old, new):
178 def _transport_changed(self, name, old, new):
179 self.engine_transport = new
179 self.engine_transport = new
180 self.client_transport = new
180 self.client_transport = new
181 self.monitor_transport = new
181 self.monitor_transport = new
182 self._update_monitor_url()
182 self._update_monitor_url()
183
183
184 def __init__(self, **kwargs):
184 def __init__(self, **kwargs):
185 super(HubFactory, self).__init__(**kwargs)
185 super(HubFactory, self).__init__(**kwargs)
186 self._update_monitor_url()
186 self._update_monitor_url()
187 # self.on_trait_change(self._sync_ips, 'ip')
187 # self.on_trait_change(self._sync_ips, 'ip')
188 # self.on_trait_change(self._sync_transports, 'transport')
188 # self.on_trait_change(self._sync_transports, 'transport')
189 self.subconstructors.append(self.construct_hub)
189 self.subconstructors.append(self.construct_hub)
190
190
191
191
192 def construct(self):
192 def construct(self):
193 assert not self._constructed, "already constructed!"
193 assert not self._constructed, "already constructed!"
194
194
195 for subc in self.subconstructors:
195 for subc in self.subconstructors:
196 subc()
196 subc()
197
197
198 self._constructed = True
198 self._constructed = True
199
199
200
200
201 def start(self):
201 def start(self):
202 assert self._constructed, "must be constructed by self.construct() first!"
202 assert self._constructed, "must be constructed by self.construct() first!"
203 self.heartmonitor.start()
203 self.heartmonitor.start()
204 self.log.info("Heartmonitor started")
204 self.log.info("Heartmonitor started")
205
205
206 def construct_hub(self):
206 def construct_hub(self):
207 """construct"""
207 """construct"""
208 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
208 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
209 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
209 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
210
210
211 ctx = self.context
211 ctx = self.context
212 loop = self.loop
212 loop = self.loop
213
213
214 # Registrar socket
214 # Registrar socket
215 q = ZMQStream(ctx.socket(zmq.XREP), loop)
215 q = ZMQStream(ctx.socket(zmq.XREP), loop)
216 q.bind(client_iface % self.regport)
216 q.bind(client_iface % self.regport)
217 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
217 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
218 if self.client_ip != self.engine_ip:
218 if self.client_ip != self.engine_ip:
219 q.bind(engine_iface % self.regport)
219 q.bind(engine_iface % self.regport)
220 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
220 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
221
221
222 ### Engine connections ###
222 ### Engine connections ###
223
223
224 # heartbeat
224 # heartbeat
225 hpub = ctx.socket(zmq.PUB)
225 hpub = ctx.socket(zmq.PUB)
226 hpub.bind(engine_iface % self.hb[0])
226 hpub.bind(engine_iface % self.hb[0])
227 hrep = ctx.socket(zmq.XREP)
227 hrep = ctx.socket(zmq.XREP)
228 hrep.bind(engine_iface % self.hb[1])
228 hrep.bind(engine_iface % self.hb[1])
229 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
229 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
230 period=self.ping, logname=self.log.name)
230 period=self.ping, logname=self.log.name)
231
231
232 ### Client connections ###
232 ### Client connections ###
233 # Notifier socket
233 # Notifier socket
234 n = ZMQStream(ctx.socket(zmq.PUB), loop)
234 n = ZMQStream(ctx.socket(zmq.PUB), loop)
235 n.bind(client_iface%self.notifier_port)
235 n.bind(client_iface%self.notifier_port)
236
236
237 ### build and launch the queues ###
237 ### build and launch the queues ###
238
238
239 # monitor socket
239 # monitor socket
240 sub = ctx.socket(zmq.SUB)
240 sub = ctx.socket(zmq.SUB)
241 sub.setsockopt(zmq.SUBSCRIBE, "")
241 sub.setsockopt(zmq.SUBSCRIBE, "")
242 sub.bind(self.monitor_url)
242 sub.bind(self.monitor_url)
243 sub.bind('inproc://monitor')
243 sub.bind('inproc://monitor')
244 sub = ZMQStream(sub, loop)
244 sub = ZMQStream(sub, loop)
245
245
246 # connect the db
246 # connect the db
247 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
247 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
248 # cdir = self.config.Global.cluster_dir
248 # cdir = self.config.Global.cluster_dir
249 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
249 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
250 time.sleep(.25)
250 time.sleep(.25)
251
251
252 # build connection dicts
252 # build connection dicts
253 self.engine_info = {
253 self.engine_info = {
254 'control' : engine_iface%self.control[1],
254 'control' : engine_iface%self.control[1],
255 'mux': engine_iface%self.mux[1],
255 'mux': engine_iface%self.mux[1],
256 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
256 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
257 'task' : engine_iface%self.task[1],
257 'task' : engine_iface%self.task[1],
258 'iopub' : engine_iface%self.iopub[1],
258 'iopub' : engine_iface%self.iopub[1],
259 # 'monitor' : engine_iface%self.mon_port,
259 # 'monitor' : engine_iface%self.mon_port,
260 }
260 }
261
261
262 self.client_info = {
262 self.client_info = {
263 'control' : client_iface%self.control[0],
263 'control' : client_iface%self.control[0],
264 'mux': client_iface%self.mux[0],
264 'mux': client_iface%self.mux[0],
265 'task' : (self.scheme, client_iface%self.task[0]),
265 'task' : (self.scheme, client_iface%self.task[0]),
266 'iopub' : client_iface%self.iopub[0],
266 'iopub' : client_iface%self.iopub[0],
267 'notification': client_iface%self.notifier_port
267 'notification': client_iface%self.notifier_port
268 }
268 }
269 self.log.debug("Hub engine addrs: %s"%self.engine_info)
269 self.log.debug("Hub engine addrs: %s"%self.engine_info)
270 self.log.debug("Hub client addrs: %s"%self.client_info)
270 self.log.debug("Hub client addrs: %s"%self.client_info)
271
272 # resubmit stream
273 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
274 url = util.disambiguate_url(self.client_info['task'][-1])
275 r.setsockopt(zmq.IDENTITY, self.session.session)
276 r.connect(url)
277
271 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
278 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
272 query=q, notifier=n, db=self.db,
279 query=q, notifier=n, resubmit=r, db=self.db,
273 engine_info=self.engine_info, client_info=self.client_info,
280 engine_info=self.engine_info, client_info=self.client_info,
274 logname=self.log.name)
281 logname=self.log.name)
275
282
276
283
277 class Hub(LoggingFactory):
284 class Hub(LoggingFactory):
278 """The IPython Controller Hub with 0MQ connections
285 """The IPython Controller Hub with 0MQ connections
279
286
280 Parameters
287 Parameters
281 ==========
288 ==========
282 loop: zmq IOLoop instance
289 loop: zmq IOLoop instance
283 session: StreamSession object
290 session: StreamSession object
284 <removed> context: zmq context for creating new connections (?)
291 <removed> context: zmq context for creating new connections (?)
285 queue: ZMQStream for monitoring the command queue (SUB)
292 queue: ZMQStream for monitoring the command queue (SUB)
286 query: ZMQStream for engine registration and client queries requests (XREP)
293 query: ZMQStream for engine registration and client queries requests (XREP)
287 heartbeat: HeartMonitor object checking the pulse of the engines
294 heartbeat: HeartMonitor object checking the pulse of the engines
288 notifier: ZMQStream for broadcasting engine registration changes (PUB)
295 notifier: ZMQStream for broadcasting engine registration changes (PUB)
289 db: connection to db for out of memory logging of commands
296 db: connection to db for out of memory logging of commands
290 NotImplemented
297 NotImplemented
291 engine_info: dict of zmq connection information for engines to connect
298 engine_info: dict of zmq connection information for engines to connect
292 to the queues.
299 to the queues.
293 client_info: dict of zmq connection information for engines to connect
300 client_info: dict of zmq connection information for engines to connect
294 to the queues.
301 to the queues.
295 """
302 """
296 # internal data structures:
303 # internal data structures:
297 ids=Set() # engine IDs
304 ids=Set() # engine IDs
298 keytable=Dict()
305 keytable=Dict()
299 by_ident=Dict()
306 by_ident=Dict()
300 engines=Dict()
307 engines=Dict()
301 clients=Dict()
308 clients=Dict()
302 hearts=Dict()
309 hearts=Dict()
303 pending=Set()
310 pending=Set()
304 queues=Dict() # pending msg_ids keyed by engine_id
311 queues=Dict() # pending msg_ids keyed by engine_id
305 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
312 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
306 completed=Dict() # completed msg_ids keyed by engine_id
313 completed=Dict() # completed msg_ids keyed by engine_id
307 all_completed=Set() # completed msg_ids keyed by engine_id
314 all_completed=Set() # completed msg_ids keyed by engine_id
308 dead_engines=Set() # completed msg_ids keyed by engine_id
315 dead_engines=Set() # completed msg_ids keyed by engine_id
309 unassigned=Set() # set of task msg_ds not yet assigned a destination
316 unassigned=Set() # set of task msg_ds not yet assigned a destination
310 incoming_registrations=Dict()
317 incoming_registrations=Dict()
311 registration_timeout=Int()
318 registration_timeout=Int()
312 _idcounter=Int(0)
319 _idcounter=Int(0)
313
320
314 # objects from constructor:
321 # objects from constructor:
315 loop=Instance(ioloop.IOLoop)
322 loop=Instance(ioloop.IOLoop)
316 query=Instance(ZMQStream)
323 query=Instance(ZMQStream)
317 monitor=Instance(ZMQStream)
324 monitor=Instance(ZMQStream)
318 heartmonitor=Instance(HeartMonitor)
319 notifier=Instance(ZMQStream)
325 notifier=Instance(ZMQStream)
326 resubmit=Instance(ZMQStream)
327 heartmonitor=Instance(HeartMonitor)
320 db=Instance(object)
328 db=Instance(object)
321 client_info=Dict()
329 client_info=Dict()
322 engine_info=Dict()
330 engine_info=Dict()
323
331
324
332
325 def __init__(self, **kwargs):
333 def __init__(self, **kwargs):
326 """
334 """
327 # universal:
335 # universal:
328 loop: IOLoop for creating future connections
336 loop: IOLoop for creating future connections
329 session: streamsession for sending serialized data
337 session: streamsession for sending serialized data
330 # engine:
338 # engine:
331 queue: ZMQStream for monitoring queue messages
339 queue: ZMQStream for monitoring queue messages
332 query: ZMQStream for engine+client registration and client requests
340 query: ZMQStream for engine+client registration and client requests
333 heartbeat: HeartMonitor object for tracking engines
341 heartbeat: HeartMonitor object for tracking engines
334 # extra:
342 # extra:
335 db: ZMQStream for db connection (NotImplemented)
343 db: ZMQStream for db connection (NotImplemented)
336 engine_info: zmq address/protocol dict for engine connections
344 engine_info: zmq address/protocol dict for engine connections
337 client_info: zmq address/protocol dict for client connections
345 client_info: zmq address/protocol dict for client connections
338 """
346 """
339
347
340 super(Hub, self).__init__(**kwargs)
348 super(Hub, self).__init__(**kwargs)
341 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
349 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
342
350
343 # validate connection dicts:
351 # validate connection dicts:
344 for k,v in self.client_info.iteritems():
352 for k,v in self.client_info.iteritems():
345 if k == 'task':
353 if k == 'task':
346 util.validate_url_container(v[1])
354 util.validate_url_container(v[1])
347 else:
355 else:
348 util.validate_url_container(v)
356 util.validate_url_container(v)
349 # util.validate_url_container(self.client_info)
357 # util.validate_url_container(self.client_info)
350 util.validate_url_container(self.engine_info)
358 util.validate_url_container(self.engine_info)
351
359
352 # register our callbacks
360 # register our callbacks
353 self.query.on_recv(self.dispatch_query)
361 self.query.on_recv(self.dispatch_query)
354 self.monitor.on_recv(self.dispatch_monitor_traffic)
362 self.monitor.on_recv(self.dispatch_monitor_traffic)
355
363
356 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
364 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
357 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
365 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
358
366
359 self.monitor_handlers = { 'in' : self.save_queue_request,
367 self.monitor_handlers = { 'in' : self.save_queue_request,
360 'out': self.save_queue_result,
368 'out': self.save_queue_result,
361 'intask': self.save_task_request,
369 'intask': self.save_task_request,
362 'outtask': self.save_task_result,
370 'outtask': self.save_task_result,
363 'tracktask': self.save_task_destination,
371 'tracktask': self.save_task_destination,
364 'incontrol': _passer,
372 'incontrol': _passer,
365 'outcontrol': _passer,
373 'outcontrol': _passer,
366 'iopub': self.save_iopub_message,
374 'iopub': self.save_iopub_message,
367 }
375 }
368
376
369 self.query_handlers = {'queue_request': self.queue_status,
377 self.query_handlers = {'queue_request': self.queue_status,
370 'result_request': self.get_results,
378 'result_request': self.get_results,
371 'history_request': self.get_history,
379 'history_request': self.get_history,
372 'db_request': self.db_query,
380 'db_request': self.db_query,
373 'purge_request': self.purge_results,
381 'purge_request': self.purge_results,
374 'load_request': self.check_load,
382 'load_request': self.check_load,
375 'resubmit_request': self.resubmit_task,
383 'resubmit_request': self.resubmit_task,
376 'shutdown_request': self.shutdown_request,
384 'shutdown_request': self.shutdown_request,
377 'registration_request' : self.register_engine,
385 'registration_request' : self.register_engine,
378 'unregistration_request' : self.unregister_engine,
386 'unregistration_request' : self.unregister_engine,
379 'connection_request': self.connection_request,
387 'connection_request': self.connection_request,
380 }
388 }
381
389
390 # ignore resubmit replies
391 self.resubmit.on_recv(lambda msg: None, copy=False)
392
382 self.log.info("hub::created hub")
393 self.log.info("hub::created hub")
383
394
384 @property
395 @property
385 def _next_id(self):
396 def _next_id(self):
386 """gemerate a new ID.
397 """gemerate a new ID.
387
398
388 No longer reuse old ids, just count from 0."""
399 No longer reuse old ids, just count from 0."""
389 newid = self._idcounter
400 newid = self._idcounter
390 self._idcounter += 1
401 self._idcounter += 1
391 return newid
402 return newid
392 # newid = 0
403 # newid = 0
393 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
404 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
394 # # print newid, self.ids, self.incoming_registrations
405 # # print newid, self.ids, self.incoming_registrations
395 # while newid in self.ids or newid in incoming:
406 # while newid in self.ids or newid in incoming:
396 # newid += 1
407 # newid += 1
397 # return newid
408 # return newid
398
409
399 #-----------------------------------------------------------------------------
410 #-----------------------------------------------------------------------------
400 # message validation
411 # message validation
401 #-----------------------------------------------------------------------------
412 #-----------------------------------------------------------------------------
402
413
403 def _validate_targets(self, targets):
414 def _validate_targets(self, targets):
404 """turn any valid targets argument into a list of integer ids"""
415 """turn any valid targets argument into a list of integer ids"""
405 if targets is None:
416 if targets is None:
406 # default to all
417 # default to all
407 targets = self.ids
418 targets = self.ids
408
419
409 if isinstance(targets, (int,str,unicode)):
420 if isinstance(targets, (int,str,unicode)):
410 # only one target specified
421 # only one target specified
411 targets = [targets]
422 targets = [targets]
412 _targets = []
423 _targets = []
413 for t in targets:
424 for t in targets:
414 # map raw identities to ids
425 # map raw identities to ids
415 if isinstance(t, (str,unicode)):
426 if isinstance(t, (str,unicode)):
416 t = self.by_ident.get(t, t)
427 t = self.by_ident.get(t, t)
417 _targets.append(t)
428 _targets.append(t)
418 targets = _targets
429 targets = _targets
419 bad_targets = [ t for t in targets if t not in self.ids ]
430 bad_targets = [ t for t in targets if t not in self.ids ]
420 if bad_targets:
431 if bad_targets:
421 raise IndexError("No Such Engine: %r"%bad_targets)
432 raise IndexError("No Such Engine: %r"%bad_targets)
422 if not targets:
433 if not targets:
423 raise IndexError("No Engines Registered")
434 raise IndexError("No Engines Registered")
424 return targets
435 return targets
425
436
426 #-----------------------------------------------------------------------------
437 #-----------------------------------------------------------------------------
427 # dispatch methods (1 per stream)
438 # dispatch methods (1 per stream)
428 #-----------------------------------------------------------------------------
439 #-----------------------------------------------------------------------------
429
440
430 # def dispatch_registration_request(self, msg):
441 # def dispatch_registration_request(self, msg):
431 # """"""
442 # """"""
432 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
443 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
433 # idents,msg = self.session.feed_identities(msg)
444 # idents,msg = self.session.feed_identities(msg)
434 # if not idents:
445 # if not idents:
435 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
446 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
436 # return
447 # return
437 # try:
448 # try:
438 # msg = self.session.unpack_message(msg,content=True)
449 # msg = self.session.unpack_message(msg,content=True)
439 # except:
450 # except:
440 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
451 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
441 # return
452 # return
442 #
453 #
443 # msg_type = msg['msg_type']
454 # msg_type = msg['msg_type']
444 # content = msg['content']
455 # content = msg['content']
445 #
456 #
446 # handler = self.query_handlers.get(msg_type, None)
457 # handler = self.query_handlers.get(msg_type, None)
447 # if handler is None:
458 # if handler is None:
448 # self.log.error("registration::got bad registration message: %s"%msg)
459 # self.log.error("registration::got bad registration message: %s"%msg)
449 # else:
460 # else:
450 # handler(idents, msg)
461 # handler(idents, msg)
451
462
452 def dispatch_monitor_traffic(self, msg):
463 def dispatch_monitor_traffic(self, msg):
453 """all ME and Task queue messages come through here, as well as
464 """all ME and Task queue messages come through here, as well as
454 IOPub traffic."""
465 IOPub traffic."""
455 self.log.debug("monitor traffic: %s"%msg[:2])
466 self.log.debug("monitor traffic: %r"%msg[:2])
456 switch = msg[0]
467 switch = msg[0]
457 idents, msg = self.session.feed_identities(msg[1:])
468 idents, msg = self.session.feed_identities(msg[1:])
458 if not idents:
469 if not idents:
459 self.log.error("Bad Monitor Message: %s"%msg)
470 self.log.error("Bad Monitor Message: %r"%msg)
460 return
471 return
461 handler = self.monitor_handlers.get(switch, None)
472 handler = self.monitor_handlers.get(switch, None)
462 if handler is not None:
473 if handler is not None:
463 handler(idents, msg)
474 handler(idents, msg)
464 else:
475 else:
465 self.log.error("Invalid monitor topic: %s"%switch)
476 self.log.error("Invalid monitor topic: %r"%switch)
466
477
467
478
468 def dispatch_query(self, msg):
479 def dispatch_query(self, msg):
469 """Route registration requests and queries from clients."""
480 """Route registration requests and queries from clients."""
470 idents, msg = self.session.feed_identities(msg)
481 idents, msg = self.session.feed_identities(msg)
471 if not idents:
482 if not idents:
472 self.log.error("Bad Query Message: %s"%msg)
483 self.log.error("Bad Query Message: %r"%msg)
473 return
484 return
474 client_id = idents[0]
485 client_id = idents[0]
475 try:
486 try:
476 msg = self.session.unpack_message(msg, content=True)
487 msg = self.session.unpack_message(msg, content=True)
477 except:
488 except:
478 content = error.wrap_exception()
489 content = error.wrap_exception()
479 self.log.error("Bad Query Message: %s"%msg, exc_info=True)
490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
480 self.session.send(self.query, "hub_error", ident=client_id,
491 self.session.send(self.query, "hub_error", ident=client_id,
481 content=content)
492 content=content)
482 return
493 return
483
494
484 # print client_id, header, parent, content
495 # print client_id, header, parent, content
485 #switch on message type:
496 #switch on message type:
486 msg_type = msg['msg_type']
497 msg_type = msg['msg_type']
487 self.log.info("client::client %s requested %s"%(client_id, msg_type))
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
488 handler = self.query_handlers.get(msg_type, None)
499 handler = self.query_handlers.get(msg_type, None)
489 try:
500 try:
490 assert handler is not None, "Bad Message Type: %s"%msg_type
501 assert handler is not None, "Bad Message Type: %r"%msg_type
491 except:
502 except:
492 content = error.wrap_exception()
503 content = error.wrap_exception()
493 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
494 self.session.send(self.query, "hub_error", ident=client_id,
505 self.session.send(self.query, "hub_error", ident=client_id,
495 content=content)
506 content=content)
496 return
507 return
508
497 else:
509 else:
498 handler(idents, msg)
510 handler(idents, msg)
499
511
500 def dispatch_db(self, msg):
512 def dispatch_db(self, msg):
501 """"""
513 """"""
502 raise NotImplementedError
514 raise NotImplementedError
503
515
504 #---------------------------------------------------------------------------
516 #---------------------------------------------------------------------------
505 # handler methods (1 per event)
517 # handler methods (1 per event)
506 #---------------------------------------------------------------------------
518 #---------------------------------------------------------------------------
507
519
508 #----------------------- Heartbeat --------------------------------------
520 #----------------------- Heartbeat --------------------------------------
509
521
510 def handle_new_heart(self, heart):
522 def handle_new_heart(self, heart):
511 """handler to attach to heartbeater.
523 """handler to attach to heartbeater.
512 Called when a new heart starts to beat.
524 Called when a new heart starts to beat.
513 Triggers completion of registration."""
525 Triggers completion of registration."""
514 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
515 if heart not in self.incoming_registrations:
527 if heart not in self.incoming_registrations:
516 self.log.info("heartbeat::ignoring new heart: %r"%heart)
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
517 else:
529 else:
518 self.finish_registration(heart)
530 self.finish_registration(heart)
519
531
520
532
521 def handle_heart_failure(self, heart):
533 def handle_heart_failure(self, heart):
522 """handler to attach to heartbeater.
534 """handler to attach to heartbeater.
523 called when a previously registered heart fails to respond to beat request.
535 called when a previously registered heart fails to respond to beat request.
524 triggers unregistration"""
536 triggers unregistration"""
525 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
526 eid = self.hearts.get(heart, None)
538 eid = self.hearts.get(heart, None)
527 queue = self.engines[eid].queue
539 queue = self.engines[eid].queue
528 if eid is None:
540 if eid is None:
529 self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
530 else:
542 else:
531 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
532
544
533 #----------------------- MUX Queue Traffic ------------------------------
545 #----------------------- MUX Queue Traffic ------------------------------
534
546
535 def save_queue_request(self, idents, msg):
547 def save_queue_request(self, idents, msg):
536 if len(idents) < 2:
548 if len(idents) < 2:
537 self.log.error("invalid identity prefix: %s"%idents)
549 self.log.error("invalid identity prefix: %s"%idents)
538 return
550 return
539 queue_id, client_id = idents[:2]
551 queue_id, client_id = idents[:2]
540 try:
552 try:
541 msg = self.session.unpack_message(msg, content=False)
553 msg = self.session.unpack_message(msg, content=False)
542 except:
554 except:
543 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
555 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
544 return
556 return
545
557
546 eid = self.by_ident.get(queue_id, None)
558 eid = self.by_ident.get(queue_id, None)
547 if eid is None:
559 if eid is None:
548 self.log.error("queue::target %r not registered"%queue_id)
560 self.log.error("queue::target %r not registered"%queue_id)
549 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
561 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
550 return
562 return
551
563
552 header = msg['header']
564 header = msg['header']
553 msg_id = header['msg_id']
565 msg_id = header['msg_id']
554 record = init_record(msg)
566 record = init_record(msg)
555 record['engine_uuid'] = queue_id
567 record['engine_uuid'] = queue_id
556 record['client_uuid'] = client_id
568 record['client_uuid'] = client_id
557 record['queue'] = 'mux'
569 record['queue'] = 'mux'
558
570
559 try:
571 try:
560 # it's posible iopub arrived first:
572 # it's posible iopub arrived first:
561 existing = self.db.get_record(msg_id)
573 existing = self.db.get_record(msg_id)
562 for key,evalue in existing.iteritems():
574 for key,evalue in existing.iteritems():
563 rvalue = record[key]
575 rvalue = record.get(key, None)
564 if evalue and rvalue and evalue != rvalue:
576 if evalue and rvalue and evalue != rvalue:
565 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
566 elif evalue and not rvalue:
578 elif evalue and not rvalue:
567 record[key] = evalue
579 record[key] = evalue
568 self.db.update_record(msg_id, record)
580 self.db.update_record(msg_id, record)
569 except KeyError:
581 except KeyError:
570 self.db.add_record(msg_id, record)
582 self.db.add_record(msg_id, record)
571
583
572 self.pending.add(msg_id)
584 self.pending.add(msg_id)
573 self.queues[eid].append(msg_id)
585 self.queues[eid].append(msg_id)
574
586
575 def save_queue_result(self, idents, msg):
587 def save_queue_result(self, idents, msg):
576 if len(idents) < 2:
588 if len(idents) < 2:
577 self.log.error("invalid identity prefix: %s"%idents)
589 self.log.error("invalid identity prefix: %s"%idents)
578 return
590 return
579
591
580 client_id, queue_id = idents[:2]
592 client_id, queue_id = idents[:2]
581 try:
593 try:
582 msg = self.session.unpack_message(msg, content=False)
594 msg = self.session.unpack_message(msg, content=False)
583 except:
595 except:
584 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
596 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
585 queue_id,client_id, msg), exc_info=True)
597 queue_id,client_id, msg), exc_info=True)
586 return
598 return
587
599
588 eid = self.by_ident.get(queue_id, None)
600 eid = self.by_ident.get(queue_id, None)
589 if eid is None:
601 if eid is None:
590 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
602 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
591 # self.log.debug("queue:: %s"%msg[2:])
603 # self.log.debug("queue:: %s"%msg[2:])
592 return
604 return
593
605
594 parent = msg['parent_header']
606 parent = msg['parent_header']
595 if not parent:
607 if not parent:
596 return
608 return
597 msg_id = parent['msg_id']
609 msg_id = parent['msg_id']
598 if msg_id in self.pending:
610 if msg_id in self.pending:
599 self.pending.remove(msg_id)
611 self.pending.remove(msg_id)
600 self.all_completed.add(msg_id)
612 self.all_completed.add(msg_id)
601 self.queues[eid].remove(msg_id)
613 self.queues[eid].remove(msg_id)
602 self.completed[eid].append(msg_id)
614 self.completed[eid].append(msg_id)
603 elif msg_id not in self.all_completed:
615 elif msg_id not in self.all_completed:
604 # it could be a result from a dead engine that died before delivering the
616 # it could be a result from a dead engine that died before delivering the
605 # result
617 # result
606 self.log.warn("queue:: unknown msg finished %s"%msg_id)
618 self.log.warn("queue:: unknown msg finished %s"%msg_id)
607 return
619 return
608 # update record anyway, because the unregistration could have been premature
620 # update record anyway, because the unregistration could have been premature
609 rheader = msg['header']
621 rheader = msg['header']
610 completed = datetime.strptime(rheader['date'], util.ISO8601)
622 completed = datetime.strptime(rheader['date'], util.ISO8601)
611 started = rheader.get('started', None)
623 started = rheader.get('started', None)
612 if started is not None:
624 if started is not None:
613 started = datetime.strptime(started, util.ISO8601)
625 started = datetime.strptime(started, util.ISO8601)
614 result = {
626 result = {
615 'result_header' : rheader,
627 'result_header' : rheader,
616 'result_content': msg['content'],
628 'result_content': msg['content'],
617 'started' : started,
629 'started' : started,
618 'completed' : completed
630 'completed' : completed
619 }
631 }
620
632
621 result['result_buffers'] = msg['buffers']
633 result['result_buffers'] = msg['buffers']
622 try:
634 try:
623 self.db.update_record(msg_id, result)
635 self.db.update_record(msg_id, result)
624 except Exception:
636 except Exception:
625 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
637 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
626
638
627
639
628 #--------------------- Task Queue Traffic ------------------------------
640 #--------------------- Task Queue Traffic ------------------------------
629
641
630 def save_task_request(self, idents, msg):
642 def save_task_request(self, idents, msg):
631 """Save the submission of a task."""
643 """Save the submission of a task."""
632 client_id = idents[0]
644 client_id = idents[0]
633
645
634 try:
646 try:
635 msg = self.session.unpack_message(msg, content=False)
647 msg = self.session.unpack_message(msg, content=False)
636 except:
648 except:
637 self.log.error("task::client %r sent invalid task message: %s"%(
649 self.log.error("task::client %r sent invalid task message: %s"%(
638 client_id, msg), exc_info=True)
650 client_id, msg), exc_info=True)
639 return
651 return
640 record = init_record(msg)
652 record = init_record(msg)
641
653
642 record['client_uuid'] = client_id
654 record['client_uuid'] = client_id
643 record['queue'] = 'task'
655 record['queue'] = 'task'
644 header = msg['header']
656 header = msg['header']
645 msg_id = header['msg_id']
657 msg_id = header['msg_id']
646 self.pending.add(msg_id)
658 self.pending.add(msg_id)
647 self.unassigned.add(msg_id)
659 self.unassigned.add(msg_id)
648 try:
660 try:
649 # it's posible iopub arrived first:
661 # it's posible iopub arrived first:
650 existing = self.db.get_record(msg_id)
662 existing = self.db.get_record(msg_id)
663 if existing['resubmitted']:
664 for key in ('submitted', 'client_uuid', 'buffers'):
665 # don't clobber these keys on resubmit
666 # submitted and client_uuid should be different
667 # and buffers might be big, and shouldn't have changed
668 record.pop(key)
669 # still check content,header which should not change
670 # but are not expensive to compare as buffers
671
651 for key,evalue in existing.iteritems():
672 for key,evalue in existing.iteritems():
652 rvalue = record[key]
673 if key.endswith('buffers'):
674 # don't compare buffers
675 continue
676 rvalue = record.get(key, None)
653 if evalue and rvalue and evalue != rvalue:
677 if evalue and rvalue and evalue != rvalue:
654 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
678 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
655 elif evalue and not rvalue:
679 elif evalue and not rvalue:
656 record[key] = evalue
680 record[key] = evalue
657 self.db.update_record(msg_id, record)
681 self.db.update_record(msg_id, record)
658 except KeyError:
682 except KeyError:
659 self.db.add_record(msg_id, record)
683 self.db.add_record(msg_id, record)
660 except Exception:
684 except Exception:
661 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
685 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
662
686
663 def save_task_result(self, idents, msg):
687 def save_task_result(self, idents, msg):
664 """save the result of a completed task."""
688 """save the result of a completed task."""
665 client_id = idents[0]
689 client_id = idents[0]
666 try:
690 try:
667 msg = self.session.unpack_message(msg, content=False)
691 msg = self.session.unpack_message(msg, content=False)
668 except:
692 except:
669 self.log.error("task::invalid task result message send to %r: %s"%(
693 self.log.error("task::invalid task result message send to %r: %s"%(
670 client_id, msg), exc_info=True)
694 client_id, msg), exc_info=True)
671 raise
695 raise
672 return
696 return
673
697
674 parent = msg['parent_header']
698 parent = msg['parent_header']
675 if not parent:
699 if not parent:
676 # print msg
700 # print msg
677 self.log.warn("Task %r had no parent!"%msg)
701 self.log.warn("Task %r had no parent!"%msg)
678 return
702 return
679 msg_id = parent['msg_id']
703 msg_id = parent['msg_id']
680 if msg_id in self.unassigned:
704 if msg_id in self.unassigned:
681 self.unassigned.remove(msg_id)
705 self.unassigned.remove(msg_id)
682
706
683 header = msg['header']
707 header = msg['header']
684 engine_uuid = header.get('engine', None)
708 engine_uuid = header.get('engine', None)
685 eid = self.by_ident.get(engine_uuid, None)
709 eid = self.by_ident.get(engine_uuid, None)
686
710
687 if msg_id in self.pending:
711 if msg_id in self.pending:
688 self.pending.remove(msg_id)
712 self.pending.remove(msg_id)
689 self.all_completed.add(msg_id)
713 self.all_completed.add(msg_id)
690 if eid is not None:
714 if eid is not None:
691 self.completed[eid].append(msg_id)
715 self.completed[eid].append(msg_id)
692 if msg_id in self.tasks[eid]:
716 if msg_id in self.tasks[eid]:
693 self.tasks[eid].remove(msg_id)
717 self.tasks[eid].remove(msg_id)
694 completed = datetime.strptime(header['date'], util.ISO8601)
718 completed = datetime.strptime(header['date'], util.ISO8601)
695 started = header.get('started', None)
719 started = header.get('started', None)
696 if started is not None:
720 if started is not None:
697 started = datetime.strptime(started, util.ISO8601)
721 started = datetime.strptime(started, util.ISO8601)
698 result = {
722 result = {
699 'result_header' : header,
723 'result_header' : header,
700 'result_content': msg['content'],
724 'result_content': msg['content'],
701 'started' : started,
725 'started' : started,
702 'completed' : completed,
726 'completed' : completed,
703 'engine_uuid': engine_uuid
727 'engine_uuid': engine_uuid
704 }
728 }
705
729
706 result['result_buffers'] = msg['buffers']
730 result['result_buffers'] = msg['buffers']
707 try:
731 try:
708 self.db.update_record(msg_id, result)
732 self.db.update_record(msg_id, result)
709 except Exception:
733 except Exception:
710 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
734 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
711
735
712 else:
736 else:
713 self.log.debug("task::unknown task %s finished"%msg_id)
737 self.log.debug("task::unknown task %s finished"%msg_id)
714
738
715 def save_task_destination(self, idents, msg):
739 def save_task_destination(self, idents, msg):
716 try:
740 try:
717 msg = self.session.unpack_message(msg, content=True)
741 msg = self.session.unpack_message(msg, content=True)
718 except:
742 except:
719 self.log.error("task::invalid task tracking message", exc_info=True)
743 self.log.error("task::invalid task tracking message", exc_info=True)
720 return
744 return
721 content = msg['content']
745 content = msg['content']
722 # print (content)
746 # print (content)
723 msg_id = content['msg_id']
747 msg_id = content['msg_id']
724 engine_uuid = content['engine_id']
748 engine_uuid = content['engine_id']
725 eid = self.by_ident[engine_uuid]
749 eid = self.by_ident[engine_uuid]
726
750
727 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
751 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
728 if msg_id in self.unassigned:
752 if msg_id in self.unassigned:
729 self.unassigned.remove(msg_id)
753 self.unassigned.remove(msg_id)
730 # else:
754 # else:
731 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
755 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
732
756
733 self.tasks[eid].append(msg_id)
757 self.tasks[eid].append(msg_id)
734 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
758 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
735 try:
759 try:
736 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
760 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
737 except Exception:
761 except Exception:
738 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
762 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
739
763
740
764
741 def mia_task_request(self, idents, msg):
765 def mia_task_request(self, idents, msg):
742 raise NotImplementedError
766 raise NotImplementedError
743 client_id = idents[0]
767 client_id = idents[0]
744 # content = dict(mia=self.mia,status='ok')
768 # content = dict(mia=self.mia,status='ok')
745 # self.session.send('mia_reply', content=content, idents=client_id)
769 # self.session.send('mia_reply', content=content, idents=client_id)
746
770
747
771
748 #--------------------- IOPub Traffic ------------------------------
772 #--------------------- IOPub Traffic ------------------------------
749
773
750 def save_iopub_message(self, topics, msg):
774 def save_iopub_message(self, topics, msg):
751 """save an iopub message into the db"""
775 """save an iopub message into the db"""
752 # print (topics)
776 # print (topics)
753 try:
777 try:
754 msg = self.session.unpack_message(msg, content=True)
778 msg = self.session.unpack_message(msg, content=True)
755 except:
779 except:
756 self.log.error("iopub::invalid IOPub message", exc_info=True)
780 self.log.error("iopub::invalid IOPub message", exc_info=True)
757 return
781 return
758
782
759 parent = msg['parent_header']
783 parent = msg['parent_header']
760 if not parent:
784 if not parent:
761 self.log.error("iopub::invalid IOPub message: %s"%msg)
785 self.log.error("iopub::invalid IOPub message: %s"%msg)
762 return
786 return
763 msg_id = parent['msg_id']
787 msg_id = parent['msg_id']
764 msg_type = msg['msg_type']
788 msg_type = msg['msg_type']
765 content = msg['content']
789 content = msg['content']
766
790
767 # ensure msg_id is in db
791 # ensure msg_id is in db
768 try:
792 try:
769 rec = self.db.get_record(msg_id)
793 rec = self.db.get_record(msg_id)
770 except KeyError:
794 except KeyError:
771 rec = empty_record()
795 rec = empty_record()
772 rec['msg_id'] = msg_id
796 rec['msg_id'] = msg_id
773 self.db.add_record(msg_id, rec)
797 self.db.add_record(msg_id, rec)
774 # stream
798 # stream
775 d = {}
799 d = {}
776 if msg_type == 'stream':
800 if msg_type == 'stream':
777 name = content['name']
801 name = content['name']
778 s = rec[name] or ''
802 s = rec[name] or ''
779 d[name] = s + content['data']
803 d[name] = s + content['data']
780
804
781 elif msg_type == 'pyerr':
805 elif msg_type == 'pyerr':
782 d['pyerr'] = content
806 d['pyerr'] = content
783 elif msg_type == 'pyin':
807 elif msg_type == 'pyin':
784 d['pyin'] = content['code']
808 d['pyin'] = content['code']
785 else:
809 else:
786 d[msg_type] = content.get('data', '')
810 d[msg_type] = content.get('data', '')
787
811
788 try:
812 try:
789 self.db.update_record(msg_id, d)
813 self.db.update_record(msg_id, d)
790 except Exception:
814 except Exception:
791 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
815 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
792
816
793
817
794
818
795 #-------------------------------------------------------------------------
819 #-------------------------------------------------------------------------
796 # Registration requests
820 # Registration requests
797 #-------------------------------------------------------------------------
821 #-------------------------------------------------------------------------
798
822
799 def connection_request(self, client_id, msg):
823 def connection_request(self, client_id, msg):
800 """Reply with connection addresses for clients."""
824 """Reply with connection addresses for clients."""
801 self.log.info("client::client %s connected"%client_id)
825 self.log.info("client::client %s connected"%client_id)
802 content = dict(status='ok')
826 content = dict(status='ok')
803 content.update(self.client_info)
827 content.update(self.client_info)
804 jsonable = {}
828 jsonable = {}
805 for k,v in self.keytable.iteritems():
829 for k,v in self.keytable.iteritems():
806 if v not in self.dead_engines:
830 if v not in self.dead_engines:
807 jsonable[str(k)] = v
831 jsonable[str(k)] = v
808 content['engines'] = jsonable
832 content['engines'] = jsonable
809 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
833 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
810
834
811 def register_engine(self, reg, msg):
835 def register_engine(self, reg, msg):
812 """Register a new engine."""
836 """Register a new engine."""
813 content = msg['content']
837 content = msg['content']
814 try:
838 try:
815 queue = content['queue']
839 queue = content['queue']
816 except KeyError:
840 except KeyError:
817 self.log.error("registration::queue not specified", exc_info=True)
841 self.log.error("registration::queue not specified", exc_info=True)
818 return
842 return
819 heart = content.get('heartbeat', None)
843 heart = content.get('heartbeat', None)
820 """register a new engine, and create the socket(s) necessary"""
844 """register a new engine, and create the socket(s) necessary"""
821 eid = self._next_id
845 eid = self._next_id
822 # print (eid, queue, reg, heart)
846 # print (eid, queue, reg, heart)
823
847
824 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
848 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
825
849
826 content = dict(id=eid,status='ok')
850 content = dict(id=eid,status='ok')
827 content.update(self.engine_info)
851 content.update(self.engine_info)
828 # check if requesting available IDs:
852 # check if requesting available IDs:
829 if queue in self.by_ident:
853 if queue in self.by_ident:
830 try:
854 try:
831 raise KeyError("queue_id %r in use"%queue)
855 raise KeyError("queue_id %r in use"%queue)
832 except:
856 except:
833 content = error.wrap_exception()
857 content = error.wrap_exception()
834 self.log.error("queue_id %r in use"%queue, exc_info=True)
858 self.log.error("queue_id %r in use"%queue, exc_info=True)
835 elif heart in self.hearts: # need to check unique hearts?
859 elif heart in self.hearts: # need to check unique hearts?
836 try:
860 try:
837 raise KeyError("heart_id %r in use"%heart)
861 raise KeyError("heart_id %r in use"%heart)
838 except:
862 except:
839 self.log.error("heart_id %r in use"%heart, exc_info=True)
863 self.log.error("heart_id %r in use"%heart, exc_info=True)
840 content = error.wrap_exception()
864 content = error.wrap_exception()
841 else:
865 else:
842 for h, pack in self.incoming_registrations.iteritems():
866 for h, pack in self.incoming_registrations.iteritems():
843 if heart == h:
867 if heart == h:
844 try:
868 try:
845 raise KeyError("heart_id %r in use"%heart)
869 raise KeyError("heart_id %r in use"%heart)
846 except:
870 except:
847 self.log.error("heart_id %r in use"%heart, exc_info=True)
871 self.log.error("heart_id %r in use"%heart, exc_info=True)
848 content = error.wrap_exception()
872 content = error.wrap_exception()
849 break
873 break
850 elif queue == pack[1]:
874 elif queue == pack[1]:
851 try:
875 try:
852 raise KeyError("queue_id %r in use"%queue)
876 raise KeyError("queue_id %r in use"%queue)
853 except:
877 except:
854 self.log.error("queue_id %r in use"%queue, exc_info=True)
878 self.log.error("queue_id %r in use"%queue, exc_info=True)
855 content = error.wrap_exception()
879 content = error.wrap_exception()
856 break
880 break
857
881
858 msg = self.session.send(self.query, "registration_reply",
882 msg = self.session.send(self.query, "registration_reply",
859 content=content,
883 content=content,
860 ident=reg)
884 ident=reg)
861
885
862 if content['status'] == 'ok':
886 if content['status'] == 'ok':
863 if heart in self.heartmonitor.hearts:
887 if heart in self.heartmonitor.hearts:
864 # already beating
888 # already beating
865 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
889 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
866 self.finish_registration(heart)
890 self.finish_registration(heart)
867 else:
891 else:
868 purge = lambda : self._purge_stalled_registration(heart)
892 purge = lambda : self._purge_stalled_registration(heart)
869 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
893 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
870 dc.start()
894 dc.start()
871 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
895 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
872 else:
896 else:
873 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
897 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
874 return eid
898 return eid
875
899
876 def unregister_engine(self, ident, msg):
900 def unregister_engine(self, ident, msg):
877 """Unregister an engine that explicitly requested to leave."""
901 """Unregister an engine that explicitly requested to leave."""
878 try:
902 try:
879 eid = msg['content']['id']
903 eid = msg['content']['id']
880 except:
904 except:
881 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
905 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
882 return
906 return
883 self.log.info("registration::unregister_engine(%s)"%eid)
907 self.log.info("registration::unregister_engine(%s)"%eid)
884 # print (eid)
908 # print (eid)
885 uuid = self.keytable[eid]
909 uuid = self.keytable[eid]
886 content=dict(id=eid, queue=uuid)
910 content=dict(id=eid, queue=uuid)
887 self.dead_engines.add(uuid)
911 self.dead_engines.add(uuid)
888 # self.ids.remove(eid)
912 # self.ids.remove(eid)
889 # uuid = self.keytable.pop(eid)
913 # uuid = self.keytable.pop(eid)
890 #
914 #
891 # ec = self.engines.pop(eid)
915 # ec = self.engines.pop(eid)
892 # self.hearts.pop(ec.heartbeat)
916 # self.hearts.pop(ec.heartbeat)
893 # self.by_ident.pop(ec.queue)
917 # self.by_ident.pop(ec.queue)
894 # self.completed.pop(eid)
918 # self.completed.pop(eid)
895 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
919 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
896 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
920 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
897 dc.start()
921 dc.start()
898 ############## TODO: HANDLE IT ################
922 ############## TODO: HANDLE IT ################
899
923
900 if self.notifier:
924 if self.notifier:
901 self.session.send(self.notifier, "unregistration_notification", content=content)
925 self.session.send(self.notifier, "unregistration_notification", content=content)
902
926
903 def _handle_stranded_msgs(self, eid, uuid):
927 def _handle_stranded_msgs(self, eid, uuid):
904 """Handle messages known to be on an engine when the engine unregisters.
928 """Handle messages known to be on an engine when the engine unregisters.
905
929
906 It is possible that this will fire prematurely - that is, an engine will
930 It is possible that this will fire prematurely - that is, an engine will
907 go down after completing a result, and the client will be notified
931 go down after completing a result, and the client will be notified
908 that the result failed and later receive the actual result.
932 that the result failed and later receive the actual result.
909 """
933 """
910
934
911 outstanding = self.queues[eid]
935 outstanding = self.queues[eid]
912
936
913 for msg_id in outstanding:
937 for msg_id in outstanding:
914 self.pending.remove(msg_id)
938 self.pending.remove(msg_id)
915 self.all_completed.add(msg_id)
939 self.all_completed.add(msg_id)
916 try:
940 try:
917 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
941 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
918 except:
942 except:
919 content = error.wrap_exception()
943 content = error.wrap_exception()
920 # build a fake header:
944 # build a fake header:
921 header = {}
945 header = {}
922 header['engine'] = uuid
946 header['engine'] = uuid
923 header['date'] = datetime.now()
947 header['date'] = datetime.now()
924 rec = dict(result_content=content, result_header=header, result_buffers=[])
948 rec = dict(result_content=content, result_header=header, result_buffers=[])
925 rec['completed'] = header['date']
949 rec['completed'] = header['date']
926 rec['engine_uuid'] = uuid
950 rec['engine_uuid'] = uuid
927 try:
951 try:
928 self.db.update_record(msg_id, rec)
952 self.db.update_record(msg_id, rec)
929 except Exception:
953 except Exception:
930 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
954 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
931
955
932
956
933 def finish_registration(self, heart):
957 def finish_registration(self, heart):
934 """Second half of engine registration, called after our HeartMonitor
958 """Second half of engine registration, called after our HeartMonitor
935 has received a beat from the Engine's Heart."""
959 has received a beat from the Engine's Heart."""
936 try:
960 try:
937 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
961 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
938 except KeyError:
962 except KeyError:
939 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
963 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
940 return
964 return
941 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
965 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
942 if purge is not None:
966 if purge is not None:
943 purge.stop()
967 purge.stop()
944 control = queue
968 control = queue
945 self.ids.add(eid)
969 self.ids.add(eid)
946 self.keytable[eid] = queue
970 self.keytable[eid] = queue
947 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
971 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
948 control=control, heartbeat=heart)
972 control=control, heartbeat=heart)
949 self.by_ident[queue] = eid
973 self.by_ident[queue] = eid
950 self.queues[eid] = list()
974 self.queues[eid] = list()
951 self.tasks[eid] = list()
975 self.tasks[eid] = list()
952 self.completed[eid] = list()
976 self.completed[eid] = list()
953 self.hearts[heart] = eid
977 self.hearts[heart] = eid
954 content = dict(id=eid, queue=self.engines[eid].queue)
978 content = dict(id=eid, queue=self.engines[eid].queue)
955 if self.notifier:
979 if self.notifier:
956 self.session.send(self.notifier, "registration_notification", content=content)
980 self.session.send(self.notifier, "registration_notification", content=content)
957 self.log.info("engine::Engine Connected: %i"%eid)
981 self.log.info("engine::Engine Connected: %i"%eid)
958
982
959 def _purge_stalled_registration(self, heart):
983 def _purge_stalled_registration(self, heart):
960 if heart in self.incoming_registrations:
984 if heart in self.incoming_registrations:
961 eid = self.incoming_registrations.pop(heart)[0]
985 eid = self.incoming_registrations.pop(heart)[0]
962 self.log.info("registration::purging stalled registration: %i"%eid)
986 self.log.info("registration::purging stalled registration: %i"%eid)
963 else:
987 else:
964 pass
988 pass
965
989
966 #-------------------------------------------------------------------------
990 #-------------------------------------------------------------------------
967 # Client Requests
991 # Client Requests
968 #-------------------------------------------------------------------------
992 #-------------------------------------------------------------------------
969
993
970 def shutdown_request(self, client_id, msg):
994 def shutdown_request(self, client_id, msg):
971 """handle shutdown request."""
995 """handle shutdown request."""
972 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
996 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
973 # also notify other clients of shutdown
997 # also notify other clients of shutdown
974 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
998 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
975 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
999 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
976 dc.start()
1000 dc.start()
977
1001
978 def _shutdown(self):
1002 def _shutdown(self):
979 self.log.info("hub::hub shutting down.")
1003 self.log.info("hub::hub shutting down.")
980 time.sleep(0.1)
1004 time.sleep(0.1)
981 sys.exit(0)
1005 sys.exit(0)
982
1006
983
1007
984 def check_load(self, client_id, msg):
1008 def check_load(self, client_id, msg):
985 content = msg['content']
1009 content = msg['content']
986 try:
1010 try:
987 targets = content['targets']
1011 targets = content['targets']
988 targets = self._validate_targets(targets)
1012 targets = self._validate_targets(targets)
989 except:
1013 except:
990 content = error.wrap_exception()
1014 content = error.wrap_exception()
991 self.session.send(self.query, "hub_error",
1015 self.session.send(self.query, "hub_error",
992 content=content, ident=client_id)
1016 content=content, ident=client_id)
993 return
1017 return
994
1018
995 content = dict(status='ok')
1019 content = dict(status='ok')
996 # loads = {}
1020 # loads = {}
997 for t in targets:
1021 for t in targets:
998 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1022 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
999 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1023 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1000
1024
1001
1025
1002 def queue_status(self, client_id, msg):
1026 def queue_status(self, client_id, msg):
1003 """Return the Queue status of one or more targets.
1027 """Return the Queue status of one or more targets.
1004 if verbose: return the msg_ids
1028 if verbose: return the msg_ids
1005 else: return len of each type.
1029 else: return len of each type.
1006 keys: queue (pending MUX jobs)
1030 keys: queue (pending MUX jobs)
1007 tasks (pending Task jobs)
1031 tasks (pending Task jobs)
1008 completed (finished jobs from both queues)"""
1032 completed (finished jobs from both queues)"""
1009 content = msg['content']
1033 content = msg['content']
1010 targets = content['targets']
1034 targets = content['targets']
1011 try:
1035 try:
1012 targets = self._validate_targets(targets)
1036 targets = self._validate_targets(targets)
1013 except:
1037 except:
1014 content = error.wrap_exception()
1038 content = error.wrap_exception()
1015 self.session.send(self.query, "hub_error",
1039 self.session.send(self.query, "hub_error",
1016 content=content, ident=client_id)
1040 content=content, ident=client_id)
1017 return
1041 return
1018 verbose = content.get('verbose', False)
1042 verbose = content.get('verbose', False)
1019 content = dict(status='ok')
1043 content = dict(status='ok')
1020 for t in targets:
1044 for t in targets:
1021 queue = self.queues[t]
1045 queue = self.queues[t]
1022 completed = self.completed[t]
1046 completed = self.completed[t]
1023 tasks = self.tasks[t]
1047 tasks = self.tasks[t]
1024 if not verbose:
1048 if not verbose:
1025 queue = len(queue)
1049 queue = len(queue)
1026 completed = len(completed)
1050 completed = len(completed)
1027 tasks = len(tasks)
1051 tasks = len(tasks)
1028 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1052 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1029 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1053 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1030
1054
1031 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1055 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1032
1056
1033 def purge_results(self, client_id, msg):
1057 def purge_results(self, client_id, msg):
1034 """Purge results from memory. This method is more valuable before we move
1058 """Purge results from memory. This method is more valuable before we move
1035 to a DB based message storage mechanism."""
1059 to a DB based message storage mechanism."""
1036 content = msg['content']
1060 content = msg['content']
1037 msg_ids = content.get('msg_ids', [])
1061 msg_ids = content.get('msg_ids', [])
1038 reply = dict(status='ok')
1062 reply = dict(status='ok')
1039 if msg_ids == 'all':
1063 if msg_ids == 'all':
1040 try:
1064 try:
1041 self.db.drop_matching_records(dict(completed={'$ne':None}))
1065 self.db.drop_matching_records(dict(completed={'$ne':None}))
1042 except Exception:
1066 except Exception:
1043 reply = error.wrap_exception()
1067 reply = error.wrap_exception()
1044 else:
1068 else:
1045 for msg_id in msg_ids:
1069 for msg_id in msg_ids:
1046 if msg_id in self.all_completed:
1070 if msg_id in self.all_completed:
1047 self.db.drop_record(msg_id)
1071 self.db.drop_record(msg_id)
1048 else:
1072 else:
1049 if msg_id in self.pending:
1073 if msg_id in self.pending:
1050 try:
1074 try:
1051 raise IndexError("msg pending: %r"%msg_id)
1075 raise IndexError("msg pending: %r"%msg_id)
1052 except:
1076 except:
1053 reply = error.wrap_exception()
1077 reply = error.wrap_exception()
1054 else:
1078 else:
1055 try:
1079 try:
1056 raise IndexError("No such msg: %r"%msg_id)
1080 raise IndexError("No such msg: %r"%msg_id)
1057 except:
1081 except:
1058 reply = error.wrap_exception()
1082 reply = error.wrap_exception()
1059 break
1083 break
1060 eids = content.get('engine_ids', [])
1084 eids = content.get('engine_ids', [])
1061 for eid in eids:
1085 for eid in eids:
1062 if eid not in self.engines:
1086 if eid not in self.engines:
1063 try:
1087 try:
1064 raise IndexError("No such engine: %i"%eid)
1088 raise IndexError("No such engine: %i"%eid)
1065 except:
1089 except:
1066 reply = error.wrap_exception()
1090 reply = error.wrap_exception()
1067 break
1091 break
1068 msg_ids = self.completed.pop(eid)
1092 msg_ids = self.completed.pop(eid)
1069 uid = self.engines[eid].queue
1093 uid = self.engines[eid].queue
1070 try:
1094 try:
1071 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1095 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1072 except Exception:
1096 except Exception:
1073 reply = error.wrap_exception()
1097 reply = error.wrap_exception()
1074 break
1098 break
1075
1099
1076 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1100 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1077
1101
1078 def resubmit_task(self, client_id, msg, buffers):
1102 def resubmit_task(self, client_id, msg):
1079 """Resubmit a task."""
1103 """Resubmit one or more tasks."""
1080 raise NotImplementedError
1104 def finish(reply):
1105 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1106
1107 content = msg['content']
1108 msg_ids = content['msg_ids']
1109 reply = dict(status='ok')
1110 try:
1111 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1112 'header', 'content', 'buffers'])
1113 except Exception:
1114 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1115 return finish(error.wrap_exception())
1116
1117 # validate msg_ids
1118 found_ids = [ rec['msg_id'] for rec in records ]
1119 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1120 if len(records) > len(msg_ids):
1121 try:
1122 raise RuntimeError("DB appears to be in an inconsistent state."
1123 "More matching records were found than should exist")
1124 except Exception:
1125 return finish(error.wrap_exception())
1126 elif len(records) < len(msg_ids):
1127 missing = [ m for m in msg_ids if m not in found_ids ]
1128 try:
1129 raise KeyError("No such msg(s): %s"%missing)
1130 except KeyError:
1131 return finish(error.wrap_exception())
1132 elif invalid_ids:
1133 msg_id = invalid_ids[0]
1134 try:
1135 raise ValueError("Task %r appears to be inflight"%(msg_id))
1136 except Exception:
1137 return finish(error.wrap_exception())
1138
1139 # clear the existing records
1140 rec = empty_record()
1141 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 rec['resubmitted'] = datetime.now()
1143 rec['queue'] = 'task'
1144 rec['client_uuid'] = client_id[0]
1145 try:
1146 for msg_id in msg_ids:
1147 self.all_completed.discard(msg_id)
1148 self.db.update_record(msg_id, rec)
1149 except Exception:
1150 self.log.error('db::db error upating record', exc_info=True)
1151 reply = error.wrap_exception()
1152 else:
1153 # send the messages
1154 for rec in records:
1155 header = rec['header']
1156 msg = self.session.msg(header['msg_type'])
1157 msg['content'] = rec['content']
1158 msg['header'] = header
1159 msg['msg_id'] = rec['msg_id']
1160 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1161
1162 finish(dict(status='ok'))
1163
1081
1164
1082 def _extract_record(self, rec):
1165 def _extract_record(self, rec):
1083 """decompose a TaskRecord dict into subsection of reply for get_result"""
1166 """decompose a TaskRecord dict into subsection of reply for get_result"""
1084 io_dict = {}
1167 io_dict = {}
1085 for key in 'pyin pyout pyerr stdout stderr'.split():
1168 for key in 'pyin pyout pyerr stdout stderr'.split():
1086 io_dict[key] = rec[key]
1169 io_dict[key] = rec[key]
1087 content = { 'result_content': rec['result_content'],
1170 content = { 'result_content': rec['result_content'],
1088 'header': rec['header'],
1171 'header': rec['header'],
1089 'result_header' : rec['result_header'],
1172 'result_header' : rec['result_header'],
1090 'io' : io_dict,
1173 'io' : io_dict,
1091 }
1174 }
1092 if rec['result_buffers']:
1175 if rec['result_buffers']:
1093 buffers = map(str, rec['result_buffers'])
1176 buffers = map(str, rec['result_buffers'])
1094 else:
1177 else:
1095 buffers = []
1178 buffers = []
1096
1179
1097 return content, buffers
1180 return content, buffers
1098
1181
1099 def get_results(self, client_id, msg):
1182 def get_results(self, client_id, msg):
1100 """Get the result of 1 or more messages."""
1183 """Get the result of 1 or more messages."""
1101 content = msg['content']
1184 content = msg['content']
1102 msg_ids = sorted(set(content['msg_ids']))
1185 msg_ids = sorted(set(content['msg_ids']))
1103 statusonly = content.get('status_only', False)
1186 statusonly = content.get('status_only', False)
1104 pending = []
1187 pending = []
1105 completed = []
1188 completed = []
1106 content = dict(status='ok')
1189 content = dict(status='ok')
1107 content['pending'] = pending
1190 content['pending'] = pending
1108 content['completed'] = completed
1191 content['completed'] = completed
1109 buffers = []
1192 buffers = []
1110 if not statusonly:
1193 if not statusonly:
1111 try:
1194 try:
1112 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1195 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1113 # turn match list into dict, for faster lookup
1196 # turn match list into dict, for faster lookup
1114 records = {}
1197 records = {}
1115 for rec in matches:
1198 for rec in matches:
1116 records[rec['msg_id']] = rec
1199 records[rec['msg_id']] = rec
1117 except Exception:
1200 except Exception:
1118 content = error.wrap_exception()
1201 content = error.wrap_exception()
1119 self.session.send(self.query, "result_reply", content=content,
1202 self.session.send(self.query, "result_reply", content=content,
1120 parent=msg, ident=client_id)
1203 parent=msg, ident=client_id)
1121 return
1204 return
1122 else:
1205 else:
1123 records = {}
1206 records = {}
1124 for msg_id in msg_ids:
1207 for msg_id in msg_ids:
1125 if msg_id in self.pending:
1208 if msg_id in self.pending:
1126 pending.append(msg_id)
1209 pending.append(msg_id)
1127 elif msg_id in self.all_completed or msg_id in records:
1210 elif msg_id in self.all_completed:
1128 completed.append(msg_id)
1211 completed.append(msg_id)
1129 if not statusonly:
1212 if not statusonly:
1130 c,bufs = self._extract_record(records[msg_id])
1213 c,bufs = self._extract_record(records[msg_id])
1131 content[msg_id] = c
1214 content[msg_id] = c
1132 buffers.extend(bufs)
1215 buffers.extend(bufs)
1216 elif msg_id in records:
1217 if rec['completed']:
1218 completed.append(msg_id)
1219 c,bufs = self._extract_record(records[msg_id])
1220 content[msg_id] = c
1221 buffers.extend(bufs)
1222 else:
1223 pending.append(msg_id)
1133 else:
1224 else:
1134 try:
1225 try:
1135 raise KeyError('No such message: '+msg_id)
1226 raise KeyError('No such message: '+msg_id)
1136 except:
1227 except:
1137 content = error.wrap_exception()
1228 content = error.wrap_exception()
1138 break
1229 break
1139 self.session.send(self.query, "result_reply", content=content,
1230 self.session.send(self.query, "result_reply", content=content,
1140 parent=msg, ident=client_id,
1231 parent=msg, ident=client_id,
1141 buffers=buffers)
1232 buffers=buffers)
1142
1233
1143 def get_history(self, client_id, msg):
1234 def get_history(self, client_id, msg):
1144 """Get a list of all msg_ids in our DB records"""
1235 """Get a list of all msg_ids in our DB records"""
1145 try:
1236 try:
1146 msg_ids = self.db.get_history()
1237 msg_ids = self.db.get_history()
1147 except Exception as e:
1238 except Exception as e:
1148 content = error.wrap_exception()
1239 content = error.wrap_exception()
1149 else:
1240 else:
1150 content = dict(status='ok', history=msg_ids)
1241 content = dict(status='ok', history=msg_ids)
1151
1242
1152 self.session.send(self.query, "history_reply", content=content,
1243 self.session.send(self.query, "history_reply", content=content,
1153 parent=msg, ident=client_id)
1244 parent=msg, ident=client_id)
1154
1245
1155 def db_query(self, client_id, msg):
1246 def db_query(self, client_id, msg):
1156 """Perform a raw query on the task record database."""
1247 """Perform a raw query on the task record database."""
1157 content = msg['content']
1248 content = msg['content']
1158 query = content.get('query', {})
1249 query = content.get('query', {})
1159 keys = content.get('keys', None)
1250 keys = content.get('keys', None)
1160 query = util.extract_dates(query)
1251 query = util.extract_dates(query)
1161 buffers = []
1252 buffers = []
1162 empty = list()
1253 empty = list()
1163
1254
1164 try:
1255 try:
1165 records = self.db.find_records(query, keys)
1256 records = self.db.find_records(query, keys)
1166 except Exception as e:
1257 except Exception as e:
1167 content = error.wrap_exception()
1258 content = error.wrap_exception()
1168 else:
1259 else:
1169 # extract buffers from reply content:
1260 # extract buffers from reply content:
1170 if keys is not None:
1261 if keys is not None:
1171 buffer_lens = [] if 'buffers' in keys else None
1262 buffer_lens = [] if 'buffers' in keys else None
1172 result_buffer_lens = [] if 'result_buffers' in keys else None
1263 result_buffer_lens = [] if 'result_buffers' in keys else None
1173 else:
1264 else:
1174 buffer_lens = []
1265 buffer_lens = []
1175 result_buffer_lens = []
1266 result_buffer_lens = []
1176
1267
1177 for rec in records:
1268 for rec in records:
1178 # buffers may be None, so double check
1269 # buffers may be None, so double check
1179 if buffer_lens is not None:
1270 if buffer_lens is not None:
1180 b = rec.pop('buffers', empty) or empty
1271 b = rec.pop('buffers', empty) or empty
1181 buffer_lens.append(len(b))
1272 buffer_lens.append(len(b))
1182 buffers.extend(b)
1273 buffers.extend(b)
1183 if result_buffer_lens is not None:
1274 if result_buffer_lens is not None:
1184 rb = rec.pop('result_buffers', empty) or empty
1275 rb = rec.pop('result_buffers', empty) or empty
1185 result_buffer_lens.append(len(rb))
1276 result_buffer_lens.append(len(rb))
1186 buffers.extend(rb)
1277 buffers.extend(rb)
1187 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1278 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1188 result_buffer_lens=result_buffer_lens)
1279 result_buffer_lens=result_buffer_lens)
1189
1280
1190 self.session.send(self.query, "db_reply", content=content,
1281 self.session.send(self.query, "db_reply", content=content,
1191 parent=msg, ident=client_id,
1282 parent=msg, ident=client_id,
1192 buffers=buffers)
1283 buffers=buffers)
1193
1284
@@ -1,416 +1,419 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
5 # Copyright (C) 2010-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11
11
12 import os
12 import os
13 import pprint
13 import pprint
14 import uuid
14 import uuid
15 from datetime import datetime
15 from datetime import datetime
16
16
17 try:
17 try:
18 import cPickle
18 import cPickle
19 pickle = cPickle
19 pickle = cPickle
20 except:
20 except:
21 cPickle = None
21 cPickle = None
22 import pickle
22 import pickle
23
23
24 import zmq
24 import zmq
25 from zmq.utils import jsonapi
25 from zmq.utils import jsonapi
26 from zmq.eventloop.zmqstream import ZMQStream
26 from zmq.eventloop.zmqstream import ZMQStream
27
27
28 from .util import ISO8601
28 from .util import ISO8601
29
29
30 def squash_unicode(obj):
30 def squash_unicode(obj):
31 """coerce unicode back to bytestrings."""
31 """coerce unicode back to bytestrings."""
32 if isinstance(obj,dict):
32 if isinstance(obj,dict):
33 for key in obj.keys():
33 for key in obj.keys():
34 obj[key] = squash_unicode(obj[key])
34 obj[key] = squash_unicode(obj[key])
35 if isinstance(key, unicode):
35 if isinstance(key, unicode):
36 obj[squash_unicode(key)] = obj.pop(key)
36 obj[squash_unicode(key)] = obj.pop(key)
37 elif isinstance(obj, list):
37 elif isinstance(obj, list):
38 for i,v in enumerate(obj):
38 for i,v in enumerate(obj):
39 obj[i] = squash_unicode(v)
39 obj[i] = squash_unicode(v)
40 elif isinstance(obj, unicode):
40 elif isinstance(obj, unicode):
41 obj = obj.encode('utf8')
41 obj = obj.encode('utf8')
42 return obj
42 return obj
43
43
44 def _date_default(obj):
44 def _date_default(obj):
45 if isinstance(obj, datetime):
45 if isinstance(obj, datetime):
46 return obj.strftime(ISO8601)
46 return obj.strftime(ISO8601)
47 else:
47 else:
48 raise TypeError("%r is not JSON serializable"%obj)
48 raise TypeError("%r is not JSON serializable"%obj)
49
49
50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
53
53
54 pickle_packer = lambda o: pickle.dumps(o,-1)
54 pickle_packer = lambda o: pickle.dumps(o,-1)
55 pickle_unpacker = pickle.loads
55 pickle_unpacker = pickle.loads
56
56
57 default_packer = json_packer
57 default_packer = json_packer
58 default_unpacker = json_unpacker
58 default_unpacker = json_unpacker
59
59
60
60
61 DELIM="<IDS|MSG>"
61 DELIM="<IDS|MSG>"
62
62
63 class Message(object):
63 class Message(object):
64 """A simple message object that maps dict keys to attributes.
64 """A simple message object that maps dict keys to attributes.
65
65
66 A Message can be created from a dict and a dict from a Message instance
66 A Message can be created from a dict and a dict from a Message instance
67 simply by calling dict(msg_obj)."""
67 simply by calling dict(msg_obj)."""
68
68
69 def __init__(self, msg_dict):
69 def __init__(self, msg_dict):
70 dct = self.__dict__
70 dct = self.__dict__
71 for k, v in dict(msg_dict).iteritems():
71 for k, v in dict(msg_dict).iteritems():
72 if isinstance(v, dict):
72 if isinstance(v, dict):
73 v = Message(v)
73 v = Message(v)
74 dct[k] = v
74 dct[k] = v
75
75
76 # Having this iterator lets dict(msg_obj) work out of the box.
76 # Having this iterator lets dict(msg_obj) work out of the box.
77 def __iter__(self):
77 def __iter__(self):
78 return iter(self.__dict__.iteritems())
78 return iter(self.__dict__.iteritems())
79
79
80 def __repr__(self):
80 def __repr__(self):
81 return repr(self.__dict__)
81 return repr(self.__dict__)
82
82
83 def __str__(self):
83 def __str__(self):
84 return pprint.pformat(self.__dict__)
84 return pprint.pformat(self.__dict__)
85
85
86 def __contains__(self, k):
86 def __contains__(self, k):
87 return k in self.__dict__
87 return k in self.__dict__
88
88
89 def __getitem__(self, k):
89 def __getitem__(self, k):
90 return self.__dict__[k]
90 return self.__dict__[k]
91
91
92
92
93 def msg_header(msg_id, msg_type, username, session):
93 def msg_header(msg_id, msg_type, username, session):
94 date=datetime.now().strftime(ISO8601)
94 date=datetime.now().strftime(ISO8601)
95 return locals()
95 return locals()
96
96
97 def extract_header(msg_or_header):
97 def extract_header(msg_or_header):
98 """Given a message or header, return the header."""
98 """Given a message or header, return the header."""
99 if not msg_or_header:
99 if not msg_or_header:
100 return {}
100 return {}
101 try:
101 try:
102 # See if msg_or_header is the entire message.
102 # See if msg_or_header is the entire message.
103 h = msg_or_header['header']
103 h = msg_or_header['header']
104 except KeyError:
104 except KeyError:
105 try:
105 try:
106 # See if msg_or_header is just the header
106 # See if msg_or_header is just the header
107 h = msg_or_header['msg_id']
107 h = msg_or_header['msg_id']
108 except KeyError:
108 except KeyError:
109 raise
109 raise
110 else:
110 else:
111 h = msg_or_header
111 h = msg_or_header
112 if not isinstance(h, dict):
112 if not isinstance(h, dict):
113 h = dict(h)
113 h = dict(h)
114 return h
114 return h
115
115
116 class StreamSession(object):
116 class StreamSession(object):
117 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
117 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
118 debug=False
118 debug=False
119 key=None
119 key=None
120
120
121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
122 if username is None:
122 if username is None:
123 username = os.environ.get('USER','username')
123 username = os.environ.get('USER','username')
124 self.username = username
124 self.username = username
125 if session is None:
125 if session is None:
126 self.session = str(uuid.uuid4())
126 self.session = str(uuid.uuid4())
127 else:
127 else:
128 self.session = session
128 self.session = session
129 self.msg_id = str(uuid.uuid4())
129 self.msg_id = str(uuid.uuid4())
130 if packer is None:
130 if packer is None:
131 self.pack = default_packer
131 self.pack = default_packer
132 else:
132 else:
133 if not callable(packer):
133 if not callable(packer):
134 raise TypeError("packer must be callable, not %s"%type(packer))
134 raise TypeError("packer must be callable, not %s"%type(packer))
135 self.pack = packer
135 self.pack = packer
136
136
137 if unpacker is None:
137 if unpacker is None:
138 self.unpack = default_unpacker
138 self.unpack = default_unpacker
139 else:
139 else:
140 if not callable(unpacker):
140 if not callable(unpacker):
141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
142 self.unpack = unpacker
142 self.unpack = unpacker
143
143
144 if key is not None and keyfile is not None:
144 if key is not None and keyfile is not None:
145 raise TypeError("Must specify key OR keyfile, not both")
145 raise TypeError("Must specify key OR keyfile, not both")
146 if keyfile is not None:
146 if keyfile is not None:
147 with open(keyfile) as f:
147 with open(keyfile) as f:
148 self.key = f.read().strip()
148 self.key = f.read().strip()
149 else:
149 else:
150 self.key = key
150 self.key = key
151 if isinstance(self.key, unicode):
151 if isinstance(self.key, unicode):
152 self.key = self.key.encode('utf8')
152 self.key = self.key.encode('utf8')
153 # print key, keyfile, self.key
153 # print key, keyfile, self.key
154 self.none = self.pack({})
154 self.none = self.pack({})
155
155
156 def msg_header(self, msg_type):
156 def msg_header(self, msg_type):
157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
158 self.msg_id = str(uuid.uuid4())
158 self.msg_id = str(uuid.uuid4())
159 return h
159 return h
160
160
161 def msg(self, msg_type, content=None, parent=None, subheader=None):
161 def msg(self, msg_type, content=None, parent=None, subheader=None):
162 msg = {}
162 msg = {}
163 msg['header'] = self.msg_header(msg_type)
163 msg['header'] = self.msg_header(msg_type)
164 msg['msg_id'] = msg['header']['msg_id']
164 msg['msg_id'] = msg['header']['msg_id']
165 msg['parent_header'] = {} if parent is None else extract_header(parent)
165 msg['parent_header'] = {} if parent is None else extract_header(parent)
166 msg['msg_type'] = msg_type
166 msg['msg_type'] = msg_type
167 msg['content'] = {} if content is None else content
167 msg['content'] = {} if content is None else content
168 sub = {} if subheader is None else subheader
168 sub = {} if subheader is None else subheader
169 msg['header'].update(sub)
169 msg['header'].update(sub)
170 return msg
170 return msg
171
171
172 def check_key(self, msg_or_header):
172 def check_key(self, msg_or_header):
173 """Check that a message's header has the right key"""
173 """Check that a message's header has the right key"""
174 if self.key is None:
174 if self.key is None:
175 return True
175 return True
176 header = extract_header(msg_or_header)
176 header = extract_header(msg_or_header)
177 return header.get('key', None) == self.key
177 return header.get('key', None) == self.key
178
178
179
179
180 def serialize(self, msg, ident=None):
180 def serialize(self, msg, ident=None):
181 content = msg.get('content', {})
181 content = msg.get('content', {})
182 if content is None:
182 if content is None:
183 content = self.none
183 content = self.none
184 elif isinstance(content, dict):
184 elif isinstance(content, dict):
185 content = self.pack(content)
185 content = self.pack(content)
186 elif isinstance(content, bytes):
186 elif isinstance(content, bytes):
187 # content is already packed, as in a relayed message
187 # content is already packed, as in a relayed message
188 pass
188 pass
189 elif isinstance(content, unicode):
190 # should be bytes, but JSON often spits out unicode
191 content = content.encode('utf8')
189 else:
192 else:
190 raise TypeError("Content incorrect type: %s"%type(content))
193 raise TypeError("Content incorrect type: %s"%type(content))
191
194
192 to_send = []
195 to_send = []
193
196
194 if isinstance(ident, list):
197 if isinstance(ident, list):
195 # accept list of idents
198 # accept list of idents
196 to_send.extend(ident)
199 to_send.extend(ident)
197 elif ident is not None:
200 elif ident is not None:
198 to_send.append(ident)
201 to_send.append(ident)
199 to_send.append(DELIM)
202 to_send.append(DELIM)
200 if self.key is not None:
203 if self.key is not None:
201 to_send.append(self.key)
204 to_send.append(self.key)
202 to_send.append(self.pack(msg['header']))
205 to_send.append(self.pack(msg['header']))
203 to_send.append(self.pack(msg['parent_header']))
206 to_send.append(self.pack(msg['parent_header']))
204 to_send.append(content)
207 to_send.append(content)
205
208
206 return to_send
209 return to_send
207
210
208 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
211 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
209 """Build and send a message via stream or socket.
212 """Build and send a message via stream or socket.
210
213
211 Parameters
214 Parameters
212 ----------
215 ----------
213
216
214 stream : zmq.Socket or ZMQStream
217 stream : zmq.Socket or ZMQStream
215 the socket-like object used to send the data
218 the socket-like object used to send the data
216 msg_or_type : str or Message/dict
219 msg_or_type : str or Message/dict
217 Normally, msg_or_type will be a msg_type unless a message is being sent more
220 Normally, msg_or_type will be a msg_type unless a message is being sent more
218 than once.
221 than once.
219
222
220 content : dict or None
223 content : dict or None
221 the content of the message (ignored if msg_or_type is a message)
224 the content of the message (ignored if msg_or_type is a message)
222 buffers : list or None
225 buffers : list or None
223 the already-serialized buffers to be appended to the message
226 the already-serialized buffers to be appended to the message
224 parent : Message or dict or None
227 parent : Message or dict or None
225 the parent or parent header describing the parent of this message
228 the parent or parent header describing the parent of this message
226 subheader : dict or None
229 subheader : dict or None
227 extra header keys for this message's header
230 extra header keys for this message's header
228 ident : bytes or list of bytes
231 ident : bytes or list of bytes
229 the zmq.IDENTITY routing path
232 the zmq.IDENTITY routing path
230 track : bool
233 track : bool
231 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
234 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
232
235
233 Returns
236 Returns
234 -------
237 -------
235 msg : message dict
238 msg : message dict
236 the constructed message
239 the constructed message
237 (msg,tracker) : (message dict, MessageTracker)
240 (msg,tracker) : (message dict, MessageTracker)
238 if track=True, then a 2-tuple will be returned, the first element being the constructed
241 if track=True, then a 2-tuple will be returned, the first element being the constructed
239 message, and the second being the MessageTracker
242 message, and the second being the MessageTracker
240
243
241 """
244 """
242
245
243 if not isinstance(stream, (zmq.Socket, ZMQStream)):
246 if not isinstance(stream, (zmq.Socket, ZMQStream)):
244 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
247 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
245 elif track and isinstance(stream, ZMQStream):
248 elif track and isinstance(stream, ZMQStream):
246 raise TypeError("ZMQStream cannot track messages")
249 raise TypeError("ZMQStream cannot track messages")
247
250
248 if isinstance(msg_or_type, (Message, dict)):
251 if isinstance(msg_or_type, (Message, dict)):
249 # we got a Message, not a msg_type
252 # we got a Message, not a msg_type
250 # don't build a new Message
253 # don't build a new Message
251 msg = msg_or_type
254 msg = msg_or_type
252 else:
255 else:
253 msg = self.msg(msg_or_type, content, parent, subheader)
256 msg = self.msg(msg_or_type, content, parent, subheader)
254
257
255 buffers = [] if buffers is None else buffers
258 buffers = [] if buffers is None else buffers
256 to_send = self.serialize(msg, ident)
259 to_send = self.serialize(msg, ident)
257 flag = 0
260 flag = 0
258 if buffers:
261 if buffers:
259 flag = zmq.SNDMORE
262 flag = zmq.SNDMORE
260 _track = False
263 _track = False
261 else:
264 else:
262 _track=track
265 _track=track
263 if track:
266 if track:
264 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
267 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
265 else:
268 else:
266 tracker = stream.send_multipart(to_send, flag, copy=False)
269 tracker = stream.send_multipart(to_send, flag, copy=False)
267 for b in buffers[:-1]:
270 for b in buffers[:-1]:
268 stream.send(b, flag, copy=False)
271 stream.send(b, flag, copy=False)
269 if buffers:
272 if buffers:
270 if track:
273 if track:
271 tracker = stream.send(buffers[-1], copy=False, track=track)
274 tracker = stream.send(buffers[-1], copy=False, track=track)
272 else:
275 else:
273 tracker = stream.send(buffers[-1], copy=False)
276 tracker = stream.send(buffers[-1], copy=False)
274
277
275 # omsg = Message(msg)
278 # omsg = Message(msg)
276 if self.debug:
279 if self.debug:
277 pprint.pprint(msg)
280 pprint.pprint(msg)
278 pprint.pprint(to_send)
281 pprint.pprint(to_send)
279 pprint.pprint(buffers)
282 pprint.pprint(buffers)
280
283
281 msg['tracker'] = tracker
284 msg['tracker'] = tracker
282
285
283 return msg
286 return msg
284
287
285 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
288 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
286 """Send a raw message via ident path.
289 """Send a raw message via ident path.
287
290
288 Parameters
291 Parameters
289 ----------
292 ----------
290 msg : list of sendable buffers"""
293 msg : list of sendable buffers"""
291 to_send = []
294 to_send = []
292 if isinstance(ident, bytes):
295 if isinstance(ident, bytes):
293 ident = [ident]
296 ident = [ident]
294 if ident is not None:
297 if ident is not None:
295 to_send.extend(ident)
298 to_send.extend(ident)
296 to_send.append(DELIM)
299 to_send.append(DELIM)
297 if self.key is not None:
300 if self.key is not None:
298 to_send.append(self.key)
301 to_send.append(self.key)
299 to_send.extend(msg)
302 to_send.extend(msg)
300 stream.send_multipart(msg, flags, copy=copy)
303 stream.send_multipart(msg, flags, copy=copy)
301
304
302 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
305 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
303 """receives and unpacks a message
306 """receives and unpacks a message
304 returns [idents], msg"""
307 returns [idents], msg"""
305 if isinstance(socket, ZMQStream):
308 if isinstance(socket, ZMQStream):
306 socket = socket.socket
309 socket = socket.socket
307 try:
310 try:
308 msg = socket.recv_multipart(mode)
311 msg = socket.recv_multipart(mode)
309 except zmq.ZMQError as e:
312 except zmq.ZMQError as e:
310 if e.errno == zmq.EAGAIN:
313 if e.errno == zmq.EAGAIN:
311 # We can convert EAGAIN to None as we know in this case
314 # We can convert EAGAIN to None as we know in this case
312 # recv_multipart won't return None.
315 # recv_multipart won't return None.
313 return None
316 return None
314 else:
317 else:
315 raise
318 raise
316 # return an actual Message object
319 # return an actual Message object
317 # determine the number of idents by trying to unpack them.
320 # determine the number of idents by trying to unpack them.
318 # this is terrible:
321 # this is terrible:
319 idents, msg = self.feed_identities(msg, copy)
322 idents, msg = self.feed_identities(msg, copy)
320 try:
323 try:
321 return idents, self.unpack_message(msg, content=content, copy=copy)
324 return idents, self.unpack_message(msg, content=content, copy=copy)
322 except Exception as e:
325 except Exception as e:
323 print (idents, msg)
326 print (idents, msg)
324 # TODO: handle it
327 # TODO: handle it
325 raise e
328 raise e
326
329
327 def feed_identities(self, msg, copy=True):
330 def feed_identities(self, msg, copy=True):
328 """feed until DELIM is reached, then return the prefix as idents and remainder as
331 """feed until DELIM is reached, then return the prefix as idents and remainder as
329 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
332 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
330
333
331 Parameters
334 Parameters
332 ----------
335 ----------
333 msg : a list of Message or bytes objects
336 msg : a list of Message or bytes objects
334 the message to be split
337 the message to be split
335 copy : bool
338 copy : bool
336 flag determining whether the arguments are bytes or Messages
339 flag determining whether the arguments are bytes or Messages
337
340
338 Returns
341 Returns
339 -------
342 -------
340 (idents,msg) : two lists
343 (idents,msg) : two lists
341 idents will always be a list of bytes - the indentity prefix
344 idents will always be a list of bytes - the indentity prefix
342 msg will be a list of bytes or Messages, unchanged from input
345 msg will be a list of bytes or Messages, unchanged from input
343 msg should be unpackable via self.unpack_message at this point.
346 msg should be unpackable via self.unpack_message at this point.
344 """
347 """
345 ikey = int(self.key is not None)
348 ikey = int(self.key is not None)
346 minlen = 3 + ikey
349 minlen = 3 + ikey
347 msg = list(msg)
350 msg = list(msg)
348 idents = []
351 idents = []
349 while len(msg) > minlen:
352 while len(msg) > minlen:
350 if copy:
353 if copy:
351 s = msg[0]
354 s = msg[0]
352 else:
355 else:
353 s = msg[0].bytes
356 s = msg[0].bytes
354 if s == DELIM:
357 if s == DELIM:
355 msg.pop(0)
358 msg.pop(0)
356 break
359 break
357 else:
360 else:
358 idents.append(s)
361 idents.append(s)
359 msg.pop(0)
362 msg.pop(0)
360
363
361 return idents, msg
364 return idents, msg
362
365
363 def unpack_message(self, msg, content=True, copy=True):
366 def unpack_message(self, msg, content=True, copy=True):
364 """Return a message object from the format
367 """Return a message object from the format
365 sent by self.send.
368 sent by self.send.
366
369
367 Parameters:
370 Parameters:
368 -----------
371 -----------
369
372
370 content : bool (True)
373 content : bool (True)
371 whether to unpack the content dict (True),
374 whether to unpack the content dict (True),
372 or leave it serialized (False)
375 or leave it serialized (False)
373
376
374 copy : bool (True)
377 copy : bool (True)
375 whether to return the bytes (True),
378 whether to return the bytes (True),
376 or the non-copying Message object in each place (False)
379 or the non-copying Message object in each place (False)
377
380
378 """
381 """
379 ikey = int(self.key is not None)
382 ikey = int(self.key is not None)
380 minlen = 3 + ikey
383 minlen = 3 + ikey
381 message = {}
384 message = {}
382 if not copy:
385 if not copy:
383 for i in range(minlen):
386 for i in range(minlen):
384 msg[i] = msg[i].bytes
387 msg[i] = msg[i].bytes
385 if ikey:
388 if ikey:
386 if not self.key == msg[0]:
389 if not self.key == msg[0]:
387 raise KeyError("Invalid Session Key: %s"%msg[0])
390 raise KeyError("Invalid Session Key: %s"%msg[0])
388 if not len(msg) >= minlen:
391 if not len(msg) >= minlen:
389 raise TypeError("malformed message, must have at least %i elements"%minlen)
392 raise TypeError("malformed message, must have at least %i elements"%minlen)
390 message['header'] = self.unpack(msg[ikey+0])
393 message['header'] = self.unpack(msg[ikey+0])
391 message['msg_type'] = message['header']['msg_type']
394 message['msg_type'] = message['header']['msg_type']
392 message['parent_header'] = self.unpack(msg[ikey+1])
395 message['parent_header'] = self.unpack(msg[ikey+1])
393 if content:
396 if content:
394 message['content'] = self.unpack(msg[ikey+2])
397 message['content'] = self.unpack(msg[ikey+2])
395 else:
398 else:
396 message['content'] = msg[ikey+2]
399 message['content'] = msg[ikey+2]
397
400
398 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
401 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
399 return message
402 return message
400
403
401
404
402 def test_msg2obj():
405 def test_msg2obj():
403 am = dict(x=1)
406 am = dict(x=1)
404 ao = Message(am)
407 ao = Message(am)
405 assert ao.x == am['x']
408 assert ao.x == am['x']
406
409
407 am['y'] = dict(z=1)
410 am['y'] = dict(z=1)
408 ao = Message(am)
411 ao = Message(am)
409 assert ao.y.z == am['y']['z']
412 assert ao.y.z == am['y']['z']
410
413
411 k1, k2 = 'y', 'z'
414 k1, k2 = 'y', 'z'
412 assert ao[k1][k2] == am[k1][k2]
415 assert ao[k1][k2] == am[k1][k2]
413
416
414 am2 = dict(ao)
417 am2 = dict(ao)
415 assert am['x'] == am2['x']
418 assert am['x'] == am2['x']
416 assert am['y']['z'] == am2['y']['z']
419 assert am['y']['z'] == am2['y']['z']
@@ -1,214 +1,237 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import time
14 import time
15 from datetime import datetime
15 from datetime import datetime
16 from tempfile import mktemp
16 from tempfile import mktemp
17
17
18 import zmq
18 import zmq
19
19
20 from IPython.parallel.client import client as clientmod
20 from IPython.parallel.client import client as clientmod
21 from IPython.parallel import error
21 from IPython.parallel import error
22 from IPython.parallel import AsyncResult, AsyncHubResult
22 from IPython.parallel import AsyncResult, AsyncHubResult
23 from IPython.parallel import LoadBalancedView, DirectView
23 from IPython.parallel import LoadBalancedView, DirectView
24
24
25 from clienttest import ClusterTestCase, segfault, wait, add_engines
25 from clienttest import ClusterTestCase, segfault, wait, add_engines
26
26
27 def setup():
27 def setup():
28 add_engines(4)
28 add_engines(4)
29
29
30 class TestClient(ClusterTestCase):
30 class TestClient(ClusterTestCase):
31
31
32 def test_ids(self):
32 def test_ids(self):
33 n = len(self.client.ids)
33 n = len(self.client.ids)
34 self.add_engines(3)
34 self.add_engines(3)
35 self.assertEquals(len(self.client.ids), n+3)
35 self.assertEquals(len(self.client.ids), n+3)
36
36
37 def test_view_indexing(self):
37 def test_view_indexing(self):
38 """test index access for views"""
38 """test index access for views"""
39 self.add_engines(2)
39 self.add_engines(2)
40 targets = self.client._build_targets('all')[-1]
40 targets = self.client._build_targets('all')[-1]
41 v = self.client[:]
41 v = self.client[:]
42 self.assertEquals(v.targets, targets)
42 self.assertEquals(v.targets, targets)
43 t = self.client.ids[2]
43 t = self.client.ids[2]
44 v = self.client[t]
44 v = self.client[t]
45 self.assert_(isinstance(v, DirectView))
45 self.assert_(isinstance(v, DirectView))
46 self.assertEquals(v.targets, t)
46 self.assertEquals(v.targets, t)
47 t = self.client.ids[2:4]
47 t = self.client.ids[2:4]
48 v = self.client[t]
48 v = self.client[t]
49 self.assert_(isinstance(v, DirectView))
49 self.assert_(isinstance(v, DirectView))
50 self.assertEquals(v.targets, t)
50 self.assertEquals(v.targets, t)
51 v = self.client[::2]
51 v = self.client[::2]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, targets[::2])
53 self.assertEquals(v.targets, targets[::2])
54 v = self.client[1::3]
54 v = self.client[1::3]
55 self.assert_(isinstance(v, DirectView))
55 self.assert_(isinstance(v, DirectView))
56 self.assertEquals(v.targets, targets[1::3])
56 self.assertEquals(v.targets, targets[1::3])
57 v = self.client[:-3]
57 v = self.client[:-3]
58 self.assert_(isinstance(v, DirectView))
58 self.assert_(isinstance(v, DirectView))
59 self.assertEquals(v.targets, targets[:-3])
59 self.assertEquals(v.targets, targets[:-3])
60 v = self.client[-1]
60 v = self.client[-1]
61 self.assert_(isinstance(v, DirectView))
61 self.assert_(isinstance(v, DirectView))
62 self.assertEquals(v.targets, targets[-1])
62 self.assertEquals(v.targets, targets[-1])
63 self.assertRaises(TypeError, lambda : self.client[None])
63 self.assertRaises(TypeError, lambda : self.client[None])
64
64
65 def test_lbview_targets(self):
65 def test_lbview_targets(self):
66 """test load_balanced_view targets"""
66 """test load_balanced_view targets"""
67 v = self.client.load_balanced_view()
67 v = self.client.load_balanced_view()
68 self.assertEquals(v.targets, None)
68 self.assertEquals(v.targets, None)
69 v = self.client.load_balanced_view(-1)
69 v = self.client.load_balanced_view(-1)
70 self.assertEquals(v.targets, [self.client.ids[-1]])
70 self.assertEquals(v.targets, [self.client.ids[-1]])
71 v = self.client.load_balanced_view('all')
71 v = self.client.load_balanced_view('all')
72 self.assertEquals(v.targets, self.client.ids)
72 self.assertEquals(v.targets, self.client.ids)
73
73
74 def test_targets(self):
74 def test_targets(self):
75 """test various valid targets arguments"""
75 """test various valid targets arguments"""
76 build = self.client._build_targets
76 build = self.client._build_targets
77 ids = self.client.ids
77 ids = self.client.ids
78 idents,targets = build(None)
78 idents,targets = build(None)
79 self.assertEquals(ids, targets)
79 self.assertEquals(ids, targets)
80
80
81 def test_clear(self):
81 def test_clear(self):
82 """test clear behavior"""
82 """test clear behavior"""
83 # self.add_engines(2)
83 # self.add_engines(2)
84 v = self.client[:]
84 v = self.client[:]
85 v.block=True
85 v.block=True
86 v.push(dict(a=5))
86 v.push(dict(a=5))
87 v.pull('a')
87 v.pull('a')
88 id0 = self.client.ids[-1]
88 id0 = self.client.ids[-1]
89 self.client.clear(targets=id0, block=True)
89 self.client.clear(targets=id0, block=True)
90 a = self.client[:-1].get('a')
90 a = self.client[:-1].get('a')
91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
92 self.client.clear(block=True)
92 self.client.clear(block=True)
93 for i in self.client.ids:
93 for i in self.client.ids:
94 # print i
94 # print i
95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
96
96
97 def test_get_result(self):
97 def test_get_result(self):
98 """test getting results from the Hub."""
98 """test getting results from the Hub."""
99 c = clientmod.Client(profile='iptest')
99 c = clientmod.Client(profile='iptest')
100 # self.add_engines(1)
100 # self.add_engines(1)
101 t = c.ids[-1]
101 t = c.ids[-1]
102 ar = c[t].apply_async(wait, 1)
102 ar = c[t].apply_async(wait, 1)
103 # give the monitor time to notice the message
103 # give the monitor time to notice the message
104 time.sleep(.25)
104 time.sleep(.25)
105 ahr = self.client.get_result(ar.msg_ids)
105 ahr = self.client.get_result(ar.msg_ids)
106 self.assertTrue(isinstance(ahr, AsyncHubResult))
106 self.assertTrue(isinstance(ahr, AsyncHubResult))
107 self.assertEquals(ahr.get(), ar.get())
107 self.assertEquals(ahr.get(), ar.get())
108 ar2 = self.client.get_result(ar.msg_ids)
108 ar2 = self.client.get_result(ar.msg_ids)
109 self.assertFalse(isinstance(ar2, AsyncHubResult))
109 self.assertFalse(isinstance(ar2, AsyncHubResult))
110 c.close()
110 c.close()
111
111
112 def test_ids_list(self):
112 def test_ids_list(self):
113 """test client.ids"""
113 """test client.ids"""
114 # self.add_engines(2)
114 # self.add_engines(2)
115 ids = self.client.ids
115 ids = self.client.ids
116 self.assertEquals(ids, self.client._ids)
116 self.assertEquals(ids, self.client._ids)
117 self.assertFalse(ids is self.client._ids)
117 self.assertFalse(ids is self.client._ids)
118 ids.remove(ids[-1])
118 ids.remove(ids[-1])
119 self.assertNotEquals(ids, self.client._ids)
119 self.assertNotEquals(ids, self.client._ids)
120
120
121 def test_queue_status(self):
121 def test_queue_status(self):
122 # self.addEngine(4)
122 # self.addEngine(4)
123 ids = self.client.ids
123 ids = self.client.ids
124 id0 = ids[0]
124 id0 = ids[0]
125 qs = self.client.queue_status(targets=id0)
125 qs = self.client.queue_status(targets=id0)
126 self.assertTrue(isinstance(qs, dict))
126 self.assertTrue(isinstance(qs, dict))
127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
128 allqs = self.client.queue_status()
128 allqs = self.client.queue_status()
129 self.assertTrue(isinstance(allqs, dict))
129 self.assertTrue(isinstance(allqs, dict))
130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
131 unassigned = allqs.pop('unassigned')
131 unassigned = allqs.pop('unassigned')
132 for eid,qs in allqs.items():
132 for eid,qs in allqs.items():
133 self.assertTrue(isinstance(qs, dict))
133 self.assertTrue(isinstance(qs, dict))
134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135
135
136 def test_shutdown(self):
136 def test_shutdown(self):
137 # self.addEngine(4)
137 # self.addEngine(4)
138 ids = self.client.ids
138 ids = self.client.ids
139 id0 = ids[0]
139 id0 = ids[0]
140 self.client.shutdown(id0, block=True)
140 self.client.shutdown(id0, block=True)
141 while id0 in self.client.ids:
141 while id0 in self.client.ids:
142 time.sleep(0.1)
142 time.sleep(0.1)
143 self.client.spin()
143 self.client.spin()
144
144
145 self.assertRaises(IndexError, lambda : self.client[id0])
145 self.assertRaises(IndexError, lambda : self.client[id0])
146
146
147 def test_result_status(self):
147 def test_result_status(self):
148 pass
148 pass
149 # to be written
149 # to be written
150
150
151 def test_db_query_dt(self):
151 def test_db_query_dt(self):
152 """test db query by date"""
152 """test db query by date"""
153 hist = self.client.hub_history()
153 hist = self.client.hub_history()
154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
155 tic = middle['submitted']
155 tic = middle['submitted']
156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
158 self.assertEquals(len(before)+len(after),len(hist))
158 self.assertEquals(len(before)+len(after),len(hist))
159 for b in before:
159 for b in before:
160 self.assertTrue(b['submitted'] < tic)
160 self.assertTrue(b['submitted'] < tic)
161 for a in after:
161 for a in after:
162 self.assertTrue(a['submitted'] >= tic)
162 self.assertTrue(a['submitted'] >= tic)
163 same = self.client.db_query({'submitted' : tic})
163 same = self.client.db_query({'submitted' : tic})
164 for s in same:
164 for s in same:
165 self.assertTrue(s['submitted'] == tic)
165 self.assertTrue(s['submitted'] == tic)
166
166
167 def test_db_query_keys(self):
167 def test_db_query_keys(self):
168 """test extracting subset of record keys"""
168 """test extracting subset of record keys"""
169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
170 for rec in found:
170 for rec in found:
171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
172
172
173 def test_db_query_msg_id(self):
173 def test_db_query_msg_id(self):
174 """ensure msg_id is always in db queries"""
174 """ensure msg_id is always in db queries"""
175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
176 for rec in found:
176 for rec in found:
177 self.assertTrue('msg_id' in rec.keys())
177 self.assertTrue('msg_id' in rec.keys())
178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
179 for rec in found:
179 for rec in found:
180 self.assertTrue('msg_id' in rec.keys())
180 self.assertTrue('msg_id' in rec.keys())
181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
182 for rec in found:
182 for rec in found:
183 self.assertTrue('msg_id' in rec.keys())
183 self.assertTrue('msg_id' in rec.keys())
184
184
185 def test_db_query_in(self):
185 def test_db_query_in(self):
186 """test db query with '$in','$nin' operators"""
186 """test db query with '$in','$nin' operators"""
187 hist = self.client.hub_history()
187 hist = self.client.hub_history()
188 even = hist[::2]
188 even = hist[::2]
189 odd = hist[1::2]
189 odd = hist[1::2]
190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
191 found = [ r['msg_id'] for r in recs ]
191 found = [ r['msg_id'] for r in recs ]
192 self.assertEquals(set(even), set(found))
192 self.assertEquals(set(even), set(found))
193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
194 found = [ r['msg_id'] for r in recs ]
194 found = [ r['msg_id'] for r in recs ]
195 self.assertEquals(set(odd), set(found))
195 self.assertEquals(set(odd), set(found))
196
196
197 def test_hub_history(self):
197 def test_hub_history(self):
198 hist = self.client.hub_history()
198 hist = self.client.hub_history()
199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
200 recdict = {}
200 recdict = {}
201 for rec in recs:
201 for rec in recs:
202 recdict[rec['msg_id']] = rec
202 recdict[rec['msg_id']] = rec
203
203
204 latest = datetime(1984,1,1)
204 latest = datetime(1984,1,1)
205 for msg_id in hist:
205 for msg_id in hist:
206 rec = recdict[msg_id]
206 rec = recdict[msg_id]
207 newt = rec['submitted']
207 newt = rec['submitted']
208 self.assertTrue(newt >= latest)
208 self.assertTrue(newt >= latest)
209 latest = newt
209 latest = newt
210 ar = self.client[-1].apply_async(lambda : 1)
210 ar = self.client[-1].apply_async(lambda : 1)
211 ar.get()
211 ar.get()
212 time.sleep(0.25)
212 time.sleep(0.25)
213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214
214
215 def test_resubmit(self):
216 def f():
217 import random
218 return random.random()
219 v = self.client.load_balanced_view()
220 ar = v.apply_async(f)
221 r1 = ar.get(1)
222 ahr = self.client.resubmit(ar.msg_ids)
223 r2 = ahr.get(1)
224 self.assertFalse(r1 == r2)
225
226 def test_resubmit_inflight(self):
227 """ensure ValueError on resubmit of inflight task"""
228 v = self.client.load_balanced_view()
229 ar = v.apply_async(time.sleep,1)
230 # give the message a chance to arrive
231 time.sleep(0.2)
232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 ar.get(2)
234
235 def test_resubmit_badkey(self):
236 """ensure KeyError on resubmit of nonexistant task"""
237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
@@ -1,120 +1,120 b''
1 """test LoadBalancedView objects"""
1 """test LoadBalancedView objects"""
2 # -*- coding: utf-8 -*-
2 # -*- coding: utf-8 -*-
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import sys
14 import sys
15 import time
15 import time
16
16
17 import zmq
17 import zmq
18
18
19 from IPython import parallel as pmod
19 from IPython import parallel as pmod
20 from IPython.parallel import error
20 from IPython.parallel import error
21
21
22 from IPython.parallel.tests import add_engines
22 from IPython.parallel.tests import add_engines
23
23
24 from .clienttest import ClusterTestCase, crash, wait, skip_without
24 from .clienttest import ClusterTestCase, crash, wait, skip_without
25
25
26 def setup():
26 def setup():
27 add_engines(3)
27 add_engines(3)
28
28
29 class TestLoadBalancedView(ClusterTestCase):
29 class TestLoadBalancedView(ClusterTestCase):
30
30
31 def setUp(self):
31 def setUp(self):
32 ClusterTestCase.setUp(self)
32 ClusterTestCase.setUp(self)
33 self.view = self.client.load_balanced_view()
33 self.view = self.client.load_balanced_view()
34
34
35 def test_z_crash_task(self):
35 def test_z_crash_task(self):
36 """test graceful handling of engine death (balanced)"""
36 """test graceful handling of engine death (balanced)"""
37 # self.add_engines(1)
37 # self.add_engines(1)
38 ar = self.view.apply_async(crash)
38 ar = self.view.apply_async(crash)
39 self.assertRaisesRemote(error.EngineError, ar.get)
39 self.assertRaisesRemote(error.EngineError, ar.get, 10)
40 eid = ar.engine_id
40 eid = ar.engine_id
41 tic = time.time()
41 tic = time.time()
42 while eid in self.client.ids and time.time()-tic < 5:
42 while eid in self.client.ids and time.time()-tic < 5:
43 time.sleep(.01)
43 time.sleep(.01)
44 self.client.spin()
44 self.client.spin()
45 self.assertFalse(eid in self.client.ids, "Engine should have died")
45 self.assertFalse(eid in self.client.ids, "Engine should have died")
46
46
47 def test_map(self):
47 def test_map(self):
48 def f(x):
48 def f(x):
49 return x**2
49 return x**2
50 data = range(16)
50 data = range(16)
51 r = self.view.map_sync(f, data)
51 r = self.view.map_sync(f, data)
52 self.assertEquals(r, map(f, data))
52 self.assertEquals(r, map(f, data))
53
53
54 def test_abort(self):
54 def test_abort(self):
55 view = self.view
55 view = self.view
56 ar = self.client[:].apply_async(time.sleep, .5)
56 ar = self.client[:].apply_async(time.sleep, .5)
57 ar2 = view.apply_async(lambda : 2)
57 ar2 = view.apply_async(lambda : 2)
58 ar3 = view.apply_async(lambda : 3)
58 ar3 = view.apply_async(lambda : 3)
59 view.abort(ar2)
59 view.abort(ar2)
60 view.abort(ar3.msg_ids)
60 view.abort(ar3.msg_ids)
61 self.assertRaises(error.TaskAborted, ar2.get)
61 self.assertRaises(error.TaskAborted, ar2.get)
62 self.assertRaises(error.TaskAborted, ar3.get)
62 self.assertRaises(error.TaskAborted, ar3.get)
63
63
64 def test_retries(self):
64 def test_retries(self):
65 add_engines(3)
65 add_engines(3)
66 view = self.view
66 view = self.view
67 view.timeout = 1 # prevent hang if this doesn't behave
67 view.timeout = 1 # prevent hang if this doesn't behave
68 def fail():
68 def fail():
69 assert False
69 assert False
70 for r in range(len(self.client)-1):
70 for r in range(len(self.client)-1):
71 with view.temp_flags(retries=r):
71 with view.temp_flags(retries=r):
72 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
72 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
73
73
74 with view.temp_flags(retries=len(self.client), timeout=0.25):
74 with view.temp_flags(retries=len(self.client), timeout=0.25):
75 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
75 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
76
76
77 def test_invalid_dependency(self):
77 def test_invalid_dependency(self):
78 view = self.view
78 view = self.view
79 with view.temp_flags(after='12345'):
79 with view.temp_flags(after='12345'):
80 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
80 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
81
81
82 def test_impossible_dependency(self):
82 def test_impossible_dependency(self):
83 if len(self.client) < 2:
83 if len(self.client) < 2:
84 add_engines(2)
84 add_engines(2)
85 view = self.client.load_balanced_view()
85 view = self.client.load_balanced_view()
86 ar1 = view.apply_async(lambda : 1)
86 ar1 = view.apply_async(lambda : 1)
87 ar1.get()
87 ar1.get()
88 e1 = ar1.engine_id
88 e1 = ar1.engine_id
89 e2 = e1
89 e2 = e1
90 while e2 == e1:
90 while e2 == e1:
91 ar2 = view.apply_async(lambda : 1)
91 ar2 = view.apply_async(lambda : 1)
92 ar2.get()
92 ar2.get()
93 e2 = ar2.engine_id
93 e2 = ar2.engine_id
94
94
95 with view.temp_flags(follow=[ar1, ar2]):
95 with view.temp_flags(follow=[ar1, ar2]):
96 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
96 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
97
97
98
98
99 def test_follow(self):
99 def test_follow(self):
100 ar = self.view.apply_async(lambda : 1)
100 ar = self.view.apply_async(lambda : 1)
101 ar.get()
101 ar.get()
102 ars = []
102 ars = []
103 first_id = ar.engine_id
103 first_id = ar.engine_id
104
104
105 self.view.follow = ar
105 self.view.follow = ar
106 for i in range(5):
106 for i in range(5):
107 ars.append(self.view.apply_async(lambda : 1))
107 ars.append(self.view.apply_async(lambda : 1))
108 self.view.wait(ars)
108 self.view.wait(ars)
109 for ar in ars:
109 for ar in ars:
110 self.assertEquals(ar.engine_id, first_id)
110 self.assertEquals(ar.engine_id, first_id)
111
111
112 def test_after(self):
112 def test_after(self):
113 view = self.view
113 view = self.view
114 ar = view.apply_async(time.sleep, 0.5)
114 ar = view.apply_async(time.sleep, 0.5)
115 with view.temp_flags(after=ar):
115 with view.temp_flags(after=ar):
116 ar2 = view.apply_async(lambda : 1)
116 ar2 = view.apply_async(lambda : 1)
117
117
118 ar.wait()
118 ar.wait()
119 ar2.wait()
119 ar2.wait()
120 self.assertTrue(ar2.started > ar.completed)
120 self.assertTrue(ar2.started > ar.completed)
General Comments 0
You need to be logged in to leave comments. Login now