##// END OF EJS Templates
introduce is_url() replacing validate_url() in cases where it is used in a control structure (as opposed to an assert)
Hannes Schulz -
Show More
@@ -1,1435 +1,1433 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import error
43 from IPython.parallel import error
44 from IPython.parallel import util
44 from IPython.parallel import util
45
45
46 from IPython.zmq.session import Session, Message
46 from IPython.zmq.session import Session, Message
47
47
48 from .asyncresult import AsyncResult, AsyncHubResult
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
50 from .view import DirectView, LoadBalancedView
51
51
52 if sys.version_info[0] >= 3:
52 if sys.version_info[0] >= 3:
53 # xrange is used in a couple 'isinstance' tests in py2
53 # xrange is used in a couple 'isinstance' tests in py2
54 # should be just 'range' in 3k
54 # should be just 'range' in 3k
55 xrange = range
55 xrange = range
56
56
57 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
58 # Decorators for Client methods
58 # Decorators for Client methods
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60
60
61 @decorator
61 @decorator
62 def spin_first(f, self, *args, **kwargs):
62 def spin_first(f, self, *args, **kwargs):
63 """Call spin() to sync state prior to calling the method."""
63 """Call spin() to sync state prior to calling the method."""
64 self.spin()
64 self.spin()
65 return f(self, *args, **kwargs)
65 return f(self, *args, **kwargs)
66
66
67
67
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71
71
72 class Metadata(dict):
72 class Metadata(dict):
73 """Subclass of dict for initializing metadata values.
73 """Subclass of dict for initializing metadata values.
74
74
75 Attribute access works on keys.
75 Attribute access works on keys.
76
76
77 These objects have a strict set of keys - errors will raise if you try
77 These objects have a strict set of keys - errors will raise if you try
78 to add new keys.
78 to add new keys.
79 """
79 """
80 def __init__(self, *args, **kwargs):
80 def __init__(self, *args, **kwargs):
81 dict.__init__(self)
81 dict.__init__(self)
82 md = {'msg_id' : None,
82 md = {'msg_id' : None,
83 'submitted' : None,
83 'submitted' : None,
84 'started' : None,
84 'started' : None,
85 'completed' : None,
85 'completed' : None,
86 'received' : None,
86 'received' : None,
87 'engine_uuid' : None,
87 'engine_uuid' : None,
88 'engine_id' : None,
88 'engine_id' : None,
89 'follow' : None,
89 'follow' : None,
90 'after' : None,
90 'after' : None,
91 'status' : None,
91 'status' : None,
92
92
93 'pyin' : None,
93 'pyin' : None,
94 'pyout' : None,
94 'pyout' : None,
95 'pyerr' : None,
95 'pyerr' : None,
96 'stdout' : '',
96 'stdout' : '',
97 'stderr' : '',
97 'stderr' : '',
98 }
98 }
99 self.update(md)
99 self.update(md)
100 self.update(dict(*args, **kwargs))
100 self.update(dict(*args, **kwargs))
101
101
102 def __getattr__(self, key):
102 def __getattr__(self, key):
103 """getattr aliased to getitem"""
103 """getattr aliased to getitem"""
104 if key in self.iterkeys():
104 if key in self.iterkeys():
105 return self[key]
105 return self[key]
106 else:
106 else:
107 raise AttributeError(key)
107 raise AttributeError(key)
108
108
109 def __setattr__(self, key, value):
109 def __setattr__(self, key, value):
110 """setattr aliased to setitem, with strict"""
110 """setattr aliased to setitem, with strict"""
111 if key in self.iterkeys():
111 if key in self.iterkeys():
112 self[key] = value
112 self[key] = value
113 else:
113 else:
114 raise AttributeError(key)
114 raise AttributeError(key)
115
115
116 def __setitem__(self, key, value):
116 def __setitem__(self, key, value):
117 """strict static key enforcement"""
117 """strict static key enforcement"""
118 if key in self.iterkeys():
118 if key in self.iterkeys():
119 dict.__setitem__(self, key, value)
119 dict.__setitem__(self, key, value)
120 else:
120 else:
121 raise KeyError(key)
121 raise KeyError(key)
122
122
123
123
124 class Client(HasTraits):
124 class Client(HasTraits):
125 """A semi-synchronous client to the IPython ZMQ cluster
125 """A semi-synchronous client to the IPython ZMQ cluster
126
126
127 Parameters
127 Parameters
128 ----------
128 ----------
129
129
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 Connection information for the Hub's registration. If a json connector
131 Connection information for the Hub's registration. If a json connector
132 file is given, then likely no further configuration is necessary.
132 file is given, then likely no further configuration is necessary.
133 [Default: use profile]
133 [Default: use profile]
134 profile : bytes
134 profile : bytes
135 The name of the Cluster profile to be used to find connector information.
135 The name of the Cluster profile to be used to find connector information.
136 If run from an IPython application, the default profile will be the same
136 If run from an IPython application, the default profile will be the same
137 as the running application, otherwise it will be 'default'.
137 as the running application, otherwise it will be 'default'.
138 context : zmq.Context
138 context : zmq.Context
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 debug : bool
140 debug : bool
141 flag for lots of message printing for debug purposes
141 flag for lots of message printing for debug purposes
142 timeout : int/float
142 timeout : int/float
143 time (in seconds) to wait for connection replies from the Hub
143 time (in seconds) to wait for connection replies from the Hub
144 [Default: 10]
144 [Default: 10]
145
145
146 #-------------- session related args ----------------
146 #-------------- session related args ----------------
147
147
148 config : Config object
148 config : Config object
149 If specified, this will be relayed to the Session for configuration
149 If specified, this will be relayed to the Session for configuration
150 username : str
150 username : str
151 set username for the session object
151 set username for the session object
152 packer : str (import_string) or callable
152 packer : str (import_string) or callable
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 function to serialize messages. Must support same input as
154 function to serialize messages. Must support same input as
155 JSON, and output must be bytes.
155 JSON, and output must be bytes.
156 You can pass a callable directly as `pack`
156 You can pass a callable directly as `pack`
157 unpacker : str (import_string) or callable
157 unpacker : str (import_string) or callable
158 The inverse of packer. Only necessary if packer is specified as *not* one
158 The inverse of packer. Only necessary if packer is specified as *not* one
159 of 'json' or 'pickle'.
159 of 'json' or 'pickle'.
160
160
161 #-------------- ssh related args ----------------
161 #-------------- ssh related args ----------------
162 # These are args for configuring the ssh tunnel to be used
162 # These are args for configuring the ssh tunnel to be used
163 # credentials are used to forward connections over ssh to the Controller
163 # credentials are used to forward connections over ssh to the Controller
164 # Note that the ip given in `addr` needs to be relative to sshserver
164 # Note that the ip given in `addr` needs to be relative to sshserver
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 # and set sshserver as the same machine the Controller is on. However,
166 # and set sshserver as the same machine the Controller is on. However,
167 # the only requirement is that sshserver is able to see the Controller
167 # the only requirement is that sshserver is able to see the Controller
168 # (i.e. is within the same trusted network).
168 # (i.e. is within the same trusted network).
169
169
170 sshserver : str
170 sshserver : str
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 If keyfile or password is specified, and this is not, it will default to
172 If keyfile or password is specified, and this is not, it will default to
173 the ip given in addr.
173 the ip given in addr.
174 sshkey : str; path to ssh private key file
174 sshkey : str; path to ssh private key file
175 This specifies a key to be used in ssh login, default None.
175 This specifies a key to be used in ssh login, default None.
176 Regular default ssh keys will be used without specifying this argument.
176 Regular default ssh keys will be used without specifying this argument.
177 password : str
177 password : str
178 Your ssh password to sshserver. Note that if this is left None,
178 Your ssh password to sshserver. Note that if this is left None,
179 you will be prompted for it if passwordless key based login is unavailable.
179 you will be prompted for it if passwordless key based login is unavailable.
180 paramiko : bool
180 paramiko : bool
181 flag for whether to use paramiko instead of shell ssh for tunneling.
181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 [default: True on win32, False else]
182 [default: True on win32, False else]
183
183
184 ------- exec authentication args -------
184 ------- exec authentication args -------
185 If even localhost is untrusted, you can have some protection against
185 If even localhost is untrusted, you can have some protection against
186 unauthorized execution by signing messages with HMAC digests.
186 unauthorized execution by signing messages with HMAC digests.
187 Messages are still sent as cleartext, so if someone can snoop your
187 Messages are still sent as cleartext, so if someone can snoop your
188 loopback traffic this will not protect your privacy, but will prevent
188 loopback traffic this will not protect your privacy, but will prevent
189 unauthorized execution.
189 unauthorized execution.
190
190
191 exec_key : str
191 exec_key : str
192 an authentication key or file containing a key
192 an authentication key or file containing a key
193 default: None
193 default: None
194
194
195
195
196 Attributes
196 Attributes
197 ----------
197 ----------
198
198
199 ids : list of int engine IDs
199 ids : list of int engine IDs
200 requesting the ids attribute always synchronizes
200 requesting the ids attribute always synchronizes
201 the registration state. To request ids without synchronization,
201 the registration state. To request ids without synchronization,
202 use semi-private _ids attributes.
202 use semi-private _ids attributes.
203
203
204 history : list of msg_ids
204 history : list of msg_ids
205 a list of msg_ids, keeping track of all the execution
205 a list of msg_ids, keeping track of all the execution
206 messages you have submitted in order.
206 messages you have submitted in order.
207
207
208 outstanding : set of msg_ids
208 outstanding : set of msg_ids
209 a set of msg_ids that have been submitted, but whose
209 a set of msg_ids that have been submitted, but whose
210 results have not yet been received.
210 results have not yet been received.
211
211
212 results : dict
212 results : dict
213 a dict of all our results, keyed by msg_id
213 a dict of all our results, keyed by msg_id
214
214
215 block : bool
215 block : bool
216 determines default behavior when block not specified
216 determines default behavior when block not specified
217 in execution methods
217 in execution methods
218
218
219 Methods
219 Methods
220 -------
220 -------
221
221
222 spin
222 spin
223 flushes incoming results and registration state changes
223 flushes incoming results and registration state changes
224 control methods spin, and requesting `ids` also ensures up to date
224 control methods spin, and requesting `ids` also ensures up to date
225
225
226 wait
226 wait
227 wait on one or more msg_ids
227 wait on one or more msg_ids
228
228
229 execution methods
229 execution methods
230 apply
230 apply
231 legacy: execute, run
231 legacy: execute, run
232
232
233 data movement
233 data movement
234 push, pull, scatter, gather
234 push, pull, scatter, gather
235
235
236 query methods
236 query methods
237 queue_status, get_result, purge, result_status
237 queue_status, get_result, purge, result_status
238
238
239 control methods
239 control methods
240 abort, shutdown
240 abort, shutdown
241
241
242 """
242 """
243
243
244
244
245 block = Bool(False)
245 block = Bool(False)
246 outstanding = Set()
246 outstanding = Set()
247 results = Instance('collections.defaultdict', (dict,))
247 results = Instance('collections.defaultdict', (dict,))
248 metadata = Instance('collections.defaultdict', (Metadata,))
248 metadata = Instance('collections.defaultdict', (Metadata,))
249 history = List()
249 history = List()
250 debug = Bool(False)
250 debug = Bool(False)
251
251
252 profile=Unicode()
252 profile=Unicode()
253 def _profile_default(self):
253 def _profile_default(self):
254 if BaseIPythonApplication.initialized():
254 if BaseIPythonApplication.initialized():
255 # an IPython app *might* be running, try to get its profile
255 # an IPython app *might* be running, try to get its profile
256 try:
256 try:
257 return BaseIPythonApplication.instance().profile
257 return BaseIPythonApplication.instance().profile
258 except (AttributeError, MultipleInstanceError):
258 except (AttributeError, MultipleInstanceError):
259 # could be a *different* subclass of config.Application,
259 # could be a *different* subclass of config.Application,
260 # which would raise one of these two errors.
260 # which would raise one of these two errors.
261 return u'default'
261 return u'default'
262 else:
262 else:
263 return u'default'
263 return u'default'
264
264
265
265
266 _outstanding_dict = Instance('collections.defaultdict', (set,))
266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 _ids = List()
267 _ids = List()
268 _connected=Bool(False)
268 _connected=Bool(False)
269 _ssh=Bool(False)
269 _ssh=Bool(False)
270 _context = Instance('zmq.Context')
270 _context = Instance('zmq.Context')
271 _config = Dict()
271 _config = Dict()
272 _engines=Instance(util.ReverseDict, (), {})
272 _engines=Instance(util.ReverseDict, (), {})
273 # _hub_socket=Instance('zmq.Socket')
273 # _hub_socket=Instance('zmq.Socket')
274 _query_socket=Instance('zmq.Socket')
274 _query_socket=Instance('zmq.Socket')
275 _control_socket=Instance('zmq.Socket')
275 _control_socket=Instance('zmq.Socket')
276 _iopub_socket=Instance('zmq.Socket')
276 _iopub_socket=Instance('zmq.Socket')
277 _notification_socket=Instance('zmq.Socket')
277 _notification_socket=Instance('zmq.Socket')
278 _mux_socket=Instance('zmq.Socket')
278 _mux_socket=Instance('zmq.Socket')
279 _task_socket=Instance('zmq.Socket')
279 _task_socket=Instance('zmq.Socket')
280 _task_scheme=Unicode()
280 _task_scheme=Unicode()
281 _closed = False
281 _closed = False
282 _ignored_control_replies=Int(0)
282 _ignored_control_replies=Int(0)
283 _ignored_hub_replies=Int(0)
283 _ignored_hub_replies=Int(0)
284
284
285 def __new__(self, *args, **kw):
285 def __new__(self, *args, **kw):
286 # don't raise on positional args
286 # don't raise on positional args
287 return HasTraits.__new__(self, **kw)
287 return HasTraits.__new__(self, **kw)
288
288
289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 context=None, debug=False, exec_key=None,
290 context=None, debug=False, exec_key=None,
291 sshserver=None, sshkey=None, password=None, paramiko=None,
291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 timeout=10, **extra_args
292 timeout=10, **extra_args
293 ):
293 ):
294 if profile:
294 if profile:
295 super(Client, self).__init__(debug=debug, profile=profile)
295 super(Client, self).__init__(debug=debug, profile=profile)
296 else:
296 else:
297 super(Client, self).__init__(debug=debug)
297 super(Client, self).__init__(debug=debug)
298 if context is None:
298 if context is None:
299 context = zmq.Context.instance()
299 context = zmq.Context.instance()
300 self._context = context
300 self._context = context
301
301
302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 if self._cd is not None:
303 if self._cd is not None:
304 if url_or_file is None:
304 if url_or_file is None:
305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 " Please specify at least one of url_or_file or profile."
307 " Please specify at least one of url_or_file or profile."
308
308
309 try:
309 if not is_url(url_or_file):
310 util.validate_url(url_or_file)
311 except AssertionError:
312 if not os.path.exists(url_or_file):
310 if not os.path.exists(url_or_file):
313 if self._cd:
311 if self._cd:
314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
312 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
313 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
316 with open(url_or_file) as f:
314 with open(url_or_file) as f:
317 cfg = json.loads(f.read())
315 cfg = json.loads(f.read())
318 else:
316 else:
319 cfg = {'url':url_or_file}
317 cfg = {'url':url_or_file}
320
318
321 # sync defaults from args, json:
319 # sync defaults from args, json:
322 if sshserver:
320 if sshserver:
323 cfg['ssh'] = sshserver
321 cfg['ssh'] = sshserver
324 if exec_key:
322 if exec_key:
325 cfg['exec_key'] = exec_key
323 cfg['exec_key'] = exec_key
326 exec_key = cfg['exec_key']
324 exec_key = cfg['exec_key']
327 location = cfg.setdefault('location', None)
325 location = cfg.setdefault('location', None)
328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
326 cfg['url'] = util.disambiguate_url(cfg['url'], location)
329 url = cfg['url']
327 url = cfg['url']
330 proto,addr,port = util.split_url(url)
328 proto,addr,port = util.split_url(url)
331 if location is not None and addr == '127.0.0.1':
329 if location is not None and addr == '127.0.0.1':
332 # location specified, and connection is expected to be local
330 # location specified, and connection is expected to be local
333 if location not in LOCAL_IPS and not sshserver:
331 if location not in LOCAL_IPS and not sshserver:
334 # load ssh from JSON *only* if the controller is not on
332 # load ssh from JSON *only* if the controller is not on
335 # this machine
333 # this machine
336 sshserver=cfg['ssh']
334 sshserver=cfg['ssh']
337 if location not in LOCAL_IPS and not sshserver:
335 if location not in LOCAL_IPS and not sshserver:
338 # warn if no ssh specified, but SSH is probably needed
336 # warn if no ssh specified, but SSH is probably needed
339 # This is only a warning, because the most likely cause
337 # This is only a warning, because the most likely cause
340 # is a local Controller on a laptop whose IP is dynamic
338 # is a local Controller on a laptop whose IP is dynamic
341 warnings.warn("""
339 warnings.warn("""
342 Controller appears to be listening on localhost, but not on this machine.
340 Controller appears to be listening on localhost, but not on this machine.
343 If this is true, you should specify Client(...,sshserver='you@%s')
341 If this is true, you should specify Client(...,sshserver='you@%s')
344 or instruct your controller to listen on an external IP."""%location,
342 or instruct your controller to listen on an external IP."""%location,
345 RuntimeWarning)
343 RuntimeWarning)
346 elif not sshserver:
344 elif not sshserver:
347 # otherwise sync with cfg
345 # otherwise sync with cfg
348 sshserver = cfg['ssh']
346 sshserver = cfg['ssh']
349
347
350 self._config = cfg
348 self._config = cfg
351
349
352 self._ssh = bool(sshserver or sshkey or password)
350 self._ssh = bool(sshserver or sshkey or password)
353 if self._ssh and sshserver is None:
351 if self._ssh and sshserver is None:
354 # default to ssh via localhost
352 # default to ssh via localhost
355 sshserver = url.split('://')[1].split(':')[0]
353 sshserver = url.split('://')[1].split(':')[0]
356 if self._ssh and password is None:
354 if self._ssh and password is None:
357 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
355 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
358 password=False
356 password=False
359 else:
357 else:
360 password = getpass("SSH Password for %s: "%sshserver)
358 password = getpass("SSH Password for %s: "%sshserver)
361 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
359 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
362
360
363 # configure and construct the session
361 # configure and construct the session
364 if exec_key is not None:
362 if exec_key is not None:
365 if os.path.isfile(exec_key):
363 if os.path.isfile(exec_key):
366 extra_args['keyfile'] = exec_key
364 extra_args['keyfile'] = exec_key
367 else:
365 else:
368 exec_key = util.asbytes(exec_key)
366 exec_key = util.asbytes(exec_key)
369 extra_args['key'] = exec_key
367 extra_args['key'] = exec_key
370 self.session = Session(**extra_args)
368 self.session = Session(**extra_args)
371
369
372 self._query_socket = self._context.socket(zmq.DEALER)
370 self._query_socket = self._context.socket(zmq.DEALER)
373 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
371 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
374 if self._ssh:
372 if self._ssh:
375 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
373 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
376 else:
374 else:
377 self._query_socket.connect(url)
375 self._query_socket.connect(url)
378
376
379 self.session.debug = self.debug
377 self.session.debug = self.debug
380
378
381 self._notification_handlers = {'registration_notification' : self._register_engine,
379 self._notification_handlers = {'registration_notification' : self._register_engine,
382 'unregistration_notification' : self._unregister_engine,
380 'unregistration_notification' : self._unregister_engine,
383 'shutdown_notification' : lambda msg: self.close(),
381 'shutdown_notification' : lambda msg: self.close(),
384 }
382 }
385 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
383 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
386 'apply_reply' : self._handle_apply_reply}
384 'apply_reply' : self._handle_apply_reply}
387 self._connect(sshserver, ssh_kwargs, timeout)
385 self._connect(sshserver, ssh_kwargs, timeout)
388
386
389 def __del__(self):
387 def __del__(self):
390 """cleanup sockets, but _not_ context."""
388 """cleanup sockets, but _not_ context."""
391 self.close()
389 self.close()
392
390
393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
391 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 if ipython_dir is None:
392 if ipython_dir is None:
395 ipython_dir = get_ipython_dir()
393 ipython_dir = get_ipython_dir()
396 if profile_dir is not None:
394 if profile_dir is not None:
397 try:
395 try:
398 self._cd = ProfileDir.find_profile_dir(profile_dir)
396 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 return
397 return
400 except ProfileDirError:
398 except ProfileDirError:
401 pass
399 pass
402 elif profile is not None:
400 elif profile is not None:
403 try:
401 try:
404 self._cd = ProfileDir.find_profile_dir_by_name(
402 self._cd = ProfileDir.find_profile_dir_by_name(
405 ipython_dir, profile)
403 ipython_dir, profile)
406 return
404 return
407 except ProfileDirError:
405 except ProfileDirError:
408 pass
406 pass
409 self._cd = None
407 self._cd = None
410
408
411 def _update_engines(self, engines):
409 def _update_engines(self, engines):
412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
410 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 for k,v in engines.iteritems():
411 for k,v in engines.iteritems():
414 eid = int(k)
412 eid = int(k)
415 self._engines[eid] = v
413 self._engines[eid] = v
416 self._ids.append(eid)
414 self._ids.append(eid)
417 self._ids = sorted(self._ids)
415 self._ids = sorted(self._ids)
418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
416 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 self._task_scheme == 'pure' and self._task_socket:
417 self._task_scheme == 'pure' and self._task_socket:
420 self._stop_scheduling_tasks()
418 self._stop_scheduling_tasks()
421
419
422 def _stop_scheduling_tasks(self):
420 def _stop_scheduling_tasks(self):
423 """Stop scheduling tasks because an engine has been unregistered
421 """Stop scheduling tasks because an engine has been unregistered
424 from a pure ZMQ scheduler.
422 from a pure ZMQ scheduler.
425 """
423 """
426 self._task_socket.close()
424 self._task_socket.close()
427 self._task_socket = None
425 self._task_socket = None
428 msg = "An engine has been unregistered, and we are using pure " +\
426 msg = "An engine has been unregistered, and we are using pure " +\
429 "ZMQ task scheduling. Task farming will be disabled."
427 "ZMQ task scheduling. Task farming will be disabled."
430 if self.outstanding:
428 if self.outstanding:
431 msg += " If you were running tasks when this happened, " +\
429 msg += " If you were running tasks when this happened, " +\
432 "some `outstanding` msg_ids may never resolve."
430 "some `outstanding` msg_ids may never resolve."
433 warnings.warn(msg, RuntimeWarning)
431 warnings.warn(msg, RuntimeWarning)
434
432
435 def _build_targets(self, targets):
433 def _build_targets(self, targets):
436 """Turn valid target IDs or 'all' into two lists:
434 """Turn valid target IDs or 'all' into two lists:
437 (int_ids, uuids).
435 (int_ids, uuids).
438 """
436 """
439 if not self._ids:
437 if not self._ids:
440 # flush notification socket if no engines yet, just in case
438 # flush notification socket if no engines yet, just in case
441 if not self.ids:
439 if not self.ids:
442 raise error.NoEnginesRegistered("Can't build targets without any engines")
440 raise error.NoEnginesRegistered("Can't build targets without any engines")
443
441
444 if targets is None:
442 if targets is None:
445 targets = self._ids
443 targets = self._ids
446 elif isinstance(targets, basestring):
444 elif isinstance(targets, basestring):
447 if targets.lower() == 'all':
445 if targets.lower() == 'all':
448 targets = self._ids
446 targets = self._ids
449 else:
447 else:
450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
448 raise TypeError("%r not valid str target, must be 'all'"%(targets))
451 elif isinstance(targets, int):
449 elif isinstance(targets, int):
452 if targets < 0:
450 if targets < 0:
453 targets = self.ids[targets]
451 targets = self.ids[targets]
454 if targets not in self._ids:
452 if targets not in self._ids:
455 raise IndexError("No such engine: %i"%targets)
453 raise IndexError("No such engine: %i"%targets)
456 targets = [targets]
454 targets = [targets]
457
455
458 if isinstance(targets, slice):
456 if isinstance(targets, slice):
459 indices = range(len(self._ids))[targets]
457 indices = range(len(self._ids))[targets]
460 ids = self.ids
458 ids = self.ids
461 targets = [ ids[i] for i in indices ]
459 targets = [ ids[i] for i in indices ]
462
460
463 if not isinstance(targets, (tuple, list, xrange)):
461 if not isinstance(targets, (tuple, list, xrange)):
464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
462 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
465
463
466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
464 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467
465
def _connect(self, sshserver, ssh_kwargs, timeout):
    """Establish all socket connections to the cluster.

    Called from __init__; a repeat call is a no-op because the state is
    latched in ``self._connected``.  Raises ``error.TimeoutError`` if the
    Hub does not answer the connection request within *timeout* seconds.
    """
    # Maybe allow reconnecting?
    if self._connected:
        return
    self._connected = True

    def connect_socket(s, url):
        # resolve the url relative to the Hub's location, then either
        # tunnel over ssh or connect the socket directly
        url = util.disambiguate_url(url, self._config['location'])
        if not self._ssh:
            return s.connect(url)
        return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)

    self.session.send(self._query_socket, 'connection_request')
    # use Poller because zmq.select has wrong units in pyzmq 2.1.7
    poller = zmq.Poller()
    poller.register(self._query_socket, zmq.POLLIN)
    # poll expects milliseconds, timeout is seconds
    if not poller.poll(timeout * 1000):
        raise error.TimeoutError("Hub connection request timed out")

    idents, msg = self.session.recv(self._query_socket, mode=0)
    if self.debug:
        pprint(msg)
    msg = Message(msg)
    content = msg.content
    self._config['registration'] = dict(content)
    if content.status != 'ok':
        # registration was refused; undo the latch and bail out
        self._connected = False
        raise Exception("Failed to connect!")

    ident = self.session.bsession
    if content.mux:
        self._mux_socket = self._context.socket(zmq.DEALER)
        self._mux_socket.setsockopt(zmq.IDENTITY, ident)
        connect_socket(self._mux_socket, content.mux)
    if content.task:
        self._task_scheme, task_addr = content.task
        self._task_socket = self._context.socket(zmq.DEALER)
        self._task_socket.setsockopt(zmq.IDENTITY, ident)
        connect_socket(self._task_socket, task_addr)
    if content.notification:
        self._notification_socket = self._context.socket(zmq.SUB)
        connect_socket(self._notification_socket, content.notification)
        # subscribe to everything the Hub publishes
        self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
    if content.control:
        self._control_socket = self._context.socket(zmq.DEALER)
        self._control_socket.setsockopt(zmq.IDENTITY, ident)
        connect_socket(self._control_socket, content.control)
    if content.iopub:
        self._iopub_socket = self._context.socket(zmq.SUB)
        self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
        self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
        connect_socket(self._iopub_socket, content.iopub)
    self._update_engines(dict(content.engines))
530
528
531 #--------------------------------------------------------------------------
529 #--------------------------------------------------------------------------
532 # handlers and callbacks for incoming messages
530 # handlers and callbacks for incoming messages
533 #--------------------------------------------------------------------------
531 #--------------------------------------------------------------------------
534
532
def _unwrap_exception(self, content):
    """unwrap exception, and remap engine_id to int."""
    # rebuild the RemoteError object from the reply content
    exc = error.unwrap_exception(content)
    # print exc.traceback
    if exc.engine_info:
        # translate the engine uuid into the integer id known locally
        engine_uuid = exc.engine_info['engine_uuid']
        exc.engine_info['engine_id'] = self._engines[engine_uuid]
    return exc
544
542
545 def _extract_metadata(self, header, parent, content):
543 def _extract_metadata(self, header, parent, content):
546 md = {'msg_id' : parent['msg_id'],
544 md = {'msg_id' : parent['msg_id'],
547 'received' : datetime.now(),
545 'received' : datetime.now(),
548 'engine_uuid' : header.get('engine', None),
546 'engine_uuid' : header.get('engine', None),
549 'follow' : parent.get('follow', []),
547 'follow' : parent.get('follow', []),
550 'after' : parent.get('after', []),
548 'after' : parent.get('after', []),
551 'status' : content['status'],
549 'status' : content['status'],
552 }
550 }
553
551
554 if md['engine_uuid'] is not None:
552 if md['engine_uuid'] is not None:
555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
553 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556
554
557 if 'date' in parent:
555 if 'date' in parent:
558 md['submitted'] = parent['date']
556 md['submitted'] = parent['date']
559 if 'started' in header:
557 if 'started' in header:
560 md['started'] = header['started']
558 md['started'] = header['started']
561 if 'date' in header:
559 if 'date' in header:
562 md['completed'] = header['date']
560 md['completed'] = header['date']
563 return md
561 return md
564
562
565 def _register_engine(self, msg):
563 def _register_engine(self, msg):
566 """Register a new engine, and update our connection info."""
564 """Register a new engine, and update our connection info."""
567 content = msg['content']
565 content = msg['content']
568 eid = content['id']
566 eid = content['id']
569 d = {eid : content['queue']}
567 d = {eid : content['queue']}
570 self._update_engines(d)
568 self._update_engines(d)
571
569
572 def _unregister_engine(self, msg):
570 def _unregister_engine(self, msg):
573 """Unregister an engine that has died."""
571 """Unregister an engine that has died."""
574 content = msg['content']
572 content = msg['content']
575 eid = int(content['id'])
573 eid = int(content['id'])
576 if eid in self._ids:
574 if eid in self._ids:
577 self._ids.remove(eid)
575 self._ids.remove(eid)
578 uuid = self._engines.pop(eid)
576 uuid = self._engines.pop(eid)
579
577
580 self._handle_stranded_msgs(eid, uuid)
578 self._handle_stranded_msgs(eid, uuid)
581
579
582 if self._task_socket and self._task_scheme == 'pure':
580 if self._task_socket and self._task_scheme == 'pure':
583 self._stop_scheduling_tasks()
581 self._stop_scheduling_tasks()
584
582
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    of the unregistration and later receive the successful result.
    """
    outstanding = self._outstanding_dict[uuid]
    # iterate over a copy; _handle_apply_reply mutates the set
    for msg_id in list(outstanding):
        if msg_id in self.results:
            # result already arrived; nothing to fake up
            continue
        try:
            raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
        except:
            content = error.wrap_exception()
        # build a fake message:
        msg = dict(
            parent_header={'msg_id': msg_id},
            header={'engine': uuid, 'date': datetime.now()},
            content=content,
        )
        self._handle_apply_reply(msg)
611
609
612 def _handle_execute_reply(self, msg):
610 def _handle_execute_reply(self, msg):
613 """Save the reply to an execute_request into our results.
611 """Save the reply to an execute_request into our results.
614
612
615 execute messages are never actually used. apply is used instead.
613 execute messages are never actually used. apply is used instead.
616 """
614 """
617
615
618 parent = msg['parent_header']
616 parent = msg['parent_header']
619 msg_id = parent['msg_id']
617 msg_id = parent['msg_id']
620 if msg_id not in self.outstanding:
618 if msg_id not in self.outstanding:
621 if msg_id in self.history:
619 if msg_id in self.history:
622 print ("got stale result: %s"%msg_id)
620 print ("got stale result: %s"%msg_id)
623 else:
621 else:
624 print ("got unknown result: %s"%msg_id)
622 print ("got unknown result: %s"%msg_id)
625 else:
623 else:
626 self.outstanding.remove(msg_id)
624 self.outstanding.remove(msg_id)
627 self.results[msg_id] = self._unwrap_exception(msg['content'])
625 self.results[msg_id] = self._unwrap_exception(msg['content'])
628
626
629 def _handle_apply_reply(self, msg):
627 def _handle_apply_reply(self, msg):
630 """Save the reply to an apply_request into our results."""
628 """Save the reply to an apply_request into our results."""
631 parent = msg['parent_header']
629 parent = msg['parent_header']
632 msg_id = parent['msg_id']
630 msg_id = parent['msg_id']
633 if msg_id not in self.outstanding:
631 if msg_id not in self.outstanding:
634 if msg_id in self.history:
632 if msg_id in self.history:
635 print ("got stale result: %s"%msg_id)
633 print ("got stale result: %s"%msg_id)
636 print self.results[msg_id]
634 print self.results[msg_id]
637 print msg
635 print msg
638 else:
636 else:
639 print ("got unknown result: %s"%msg_id)
637 print ("got unknown result: %s"%msg_id)
640 else:
638 else:
641 self.outstanding.remove(msg_id)
639 self.outstanding.remove(msg_id)
642 content = msg['content']
640 content = msg['content']
643 header = msg['header']
641 header = msg['header']
644
642
645 # construct metadata:
643 # construct metadata:
646 md = self.metadata[msg_id]
644 md = self.metadata[msg_id]
647 md.update(self._extract_metadata(header, parent, content))
645 md.update(self._extract_metadata(header, parent, content))
648 # is this redundant?
646 # is this redundant?
649 self.metadata[msg_id] = md
647 self.metadata[msg_id] = md
650
648
651 e_outstanding = self._outstanding_dict[md['engine_uuid']]
649 e_outstanding = self._outstanding_dict[md['engine_uuid']]
652 if msg_id in e_outstanding:
650 if msg_id in e_outstanding:
653 e_outstanding.remove(msg_id)
651 e_outstanding.remove(msg_id)
654
652
655 # construct result:
653 # construct result:
656 if content['status'] == 'ok':
654 if content['status'] == 'ok':
657 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
655 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
658 elif content['status'] == 'aborted':
656 elif content['status'] == 'aborted':
659 self.results[msg_id] = error.TaskAborted(msg_id)
657 self.results[msg_id] = error.TaskAborted(msg_id)
660 elif content['status'] == 'resubmitted':
658 elif content['status'] == 'resubmitted':
661 # TODO: handle resubmission
659 # TODO: handle resubmission
662 pass
660 pass
663 else:
661 else:
664 self.results[msg_id] = self._unwrap_exception(content)
662 self.results[msg_id] = self._unwrap_exception(content)
665
663
def _flush_notifications(self):
    """Flush notifications of engine registrations waiting
    in ZMQ queue.

    Drains the notification SUB socket without blocking and dispatches
    each message to the handler registered for its msg_type.
    Raises Exception on an unrecognized msg_type.
    """
    idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
    while msg is not None:
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            # BUG FIX: msg is a dict, so the old ``msg.msg_type`` attribute
            # access raised AttributeError before this Exception was seen;
            # report the msg_type extracted above instead.
            raise Exception("Unhandled message type: %s"%msg_type)
        else:
            handler(msg)
        idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680
678
def _flush_results(self, sock):
    """Flush task or queue results waiting in ZMQ queue.

    Drains *sock* without blocking and dispatches each reply to the
    handler registered for its msg_type.
    Raises Exception on an unrecognized msg_type.
    """
    idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
    while msg is not None:
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._queue_handlers.get(msg_type, None)
        if handler is None:
            # BUG FIX: msg is a dict, so the old ``msg.msg_type`` attribute
            # access raised AttributeError before this Exception was seen;
            # report the msg_type extracted above instead.
            raise Exception("Unhandled message type: %s"%msg_type)
        else:
            handler(msg)
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694
692
def _flush_control(self, sock):
    """Flush replies from the control channel waiting
    in the ZMQ queue.

    Currently: ignore them."""
    if self._ignored_control_replies <= 0:
        # nothing owed on this channel
        return
    while True:
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            break
        self._ignored_control_replies -= 1
        if self.debug:
            pprint(msg)
708
706
709 def _flush_ignored_control(self):
707 def _flush_ignored_control(self):
710 """flush ignored control replies"""
708 """flush ignored control replies"""
711 while self._ignored_control_replies > 0:
709 while self._ignored_control_replies > 0:
712 self.session.recv(self._control_socket)
710 self.session.recv(self._control_socket)
713 self._ignored_control_replies -= 1
711 self._ignored_control_replies -= 1
714
712
def _flush_ignored_hub_replies(self):
    # drain and discard any hub replies we said we didn't care about
    msg = True
    while msg is not None:
        ident, msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719
717
def _flush_iopub(self, sock):
    """Flush replies from the iopub channel waiting
    in the ZMQ queue.

    Folds stdout/stderr streams, pyin and pyerr messages into the
    per-request metadata record.
    """
    idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
    while msg is not None:
        if self.debug:
            pprint(msg)
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        content = msg['content']
        msg_type = msg['header']['msg_type']

        # init metadata:
        md = self.metadata[msg_id]

        if msg_type == 'stream':
            # append to any stream output already collected
            name = content['name']
            s = md[name] or ''
            md[name] = s + content['data']
        elif msg_type == 'pyerr':
            md.update({'pyerr' : self._unwrap_exception(content)})
        elif msg_type == 'pyin':
            md.update({'pyin' : content['code']})
        else:
            md.update({msg_type : content.get('data', '')})

        # reduntant?
        self.metadata[msg_id] = md

        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
752
750
753 #--------------------------------------------------------------------------
751 #--------------------------------------------------------------------------
754 # len, getitem
752 # len, getitem
755 #--------------------------------------------------------------------------
753 #--------------------------------------------------------------------------
756
754
757 def __len__(self):
755 def __len__(self):
758 """len(client) returns # of engines."""
756 """len(client) returns # of engines."""
759 return len(self.ids)
757 return len(self.ids)
760
758
def __getitem__(self, key):
    """index access returns DirectView multiplexer objects

    Must be int, slice, or list/tuple/xrange of ints"""
    if isinstance(key, (int, slice, tuple, list, xrange)):
        return self.direct_view(key)
    raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
769
767
770 #--------------------------------------------------------------------------
768 #--------------------------------------------------------------------------
771 # Begin public methods
769 # Begin public methods
772 #--------------------------------------------------------------------------
770 #--------------------------------------------------------------------------
773
771
@property
def ids(self):
    """Always up-to-date ids property."""
    # process pending registration/unregistration notifications first,
    # so the list reflects the current cluster
    self._flush_notifications()
    # always return a copy so callers can't mutate our state
    return list(self._ids)
780
778
def close(self):
    """Close every zmq socket this client owns (idempotent)."""
    if self._closed:
        return
    # all of our zmq socket attributes follow the '*socket' naming pattern
    for name in dir(self):
        if not name.endswith('socket'):
            continue
        sock = getattr(self, name)
        if isinstance(sock, zmq.Socket) and not sock.closed:
            sock.close()
    self._closed = True
789
787
def spin(self):
    """Flush any registration notifications and execution results
    waiting in the ZMQ queue.
    """
    # (socket, flusher, does the flusher take the socket as argument?)
    channels = (
        (self._notification_socket, self._flush_notifications, False),
        (self._mux_socket, self._flush_results, True),
        (self._task_socket, self._flush_results, True),
        (self._control_socket, self._flush_control, True),
        (self._iopub_socket, self._flush_iopub, True),
        (self._query_socket, self._flush_ignored_hub_replies, False),
    )
    for sock, flush, takes_sock in channels:
        if not sock:
            # channel was never connected
            continue
        if takes_sock:
            flush(sock)
        else:
            flush()
806
804
def wait(self, jobs=None, timeout=-1):
    """waits on one or more `jobs`, for up to `timeout` seconds.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
        ints are indices to self.history
        strs are msg_ids
        default: wait on all outstanding messages
    timeout : float
        a time in seconds, after which to give up.
        default is -1, which means no timeout

    Returns
    -------

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    start = time.time()
    if jobs is None:
        # wait on everything currently in flight
        theids = self.outstanding
    else:
        if isinstance(jobs, (int, basestring, AsyncResult)):
            jobs = [jobs]
        theids = set()
        for job in jobs:
            if isinstance(job, int):
                # index access into our submission history
                job = self.history[job]
            elif isinstance(job, AsyncResult):
                theids.update(job.msg_ids)
                continue
            theids.add(job)
    if not theids.intersection(self.outstanding):
        # nothing requested is still pending
        return True
    self.spin()
    while theids.intersection(self.outstanding):
        if timeout >= 0 and (time.time() - start) > timeout:
            break
        time.sleep(1e-3)
        self.spin()
    return len(theids.intersection(self.outstanding)) == 0
851
849
852 #--------------------------------------------------------------------------
850 #--------------------------------------------------------------------------
853 # Control methods
851 # Control methods
854 #--------------------------------------------------------------------------
852 #--------------------------------------------------------------------------
855
853
@spin_first
def clear(self, targets=None, block=None):
    """Clear the namespace in target(s)."""
    block = self.block if block is None else block
    targets = self._build_targets(targets)[0]
    for t in targets:
        self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
    error = False
    if not block:
        # non-blocking: remember how many replies to discard later
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])
    if error:
        raise error
876
874
877
875
@spin_first
def abort(self, jobs=None, targets=None, block=None):
    """Abort specific jobs from the execution queues of target(s).

    This is a mechanism to prevent jobs that have already been submitted
    from executing.

    Parameters
    ----------

    jobs : msg_id, list of msg_ids, or AsyncResult
        The jobs to be aborted


    """
    block = self.block if block is None else block
    targets = self._build_targets(targets)[0]
    if isinstance(jobs, (basestring, AsyncResult)):
        jobs = [jobs]
    # reject anything that is not a msg_id or AsyncResult up front
    bad_ids = [j for j in jobs if not isinstance(j, (basestring, AsyncResult))]
    if bad_ids:
        raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
    msg_ids = []
    for j in jobs:
        if isinstance(j, AsyncResult):
            msg_ids.extend(j.msg_ids)
        else:
            msg_ids.append(j)
    content = dict(msg_ids=msg_ids)
    for t in targets:
        self.session.send(self._control_socket, 'abort_request',
                content=content, ident=t)
    error = False
    if not block:
        # non-blocking: remember how many replies to discard later
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])
    if error:
        raise error
923
921
@spin_first
def shutdown(self, targets=None, restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub."""
    block = self.block if block is None else block
    if hub:
        # shutting down the hub implies shutting down all engines
        targets = 'all'
    targets = self._build_targets(targets)[0]
    for t in targets:
        self.session.send(self._control_socket, 'shutdown_request',
                content={'restart':restart},ident=t)
    error = False
    if not (block or hub):
        # non-blocking: remember how many replies to discard later
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

    if hub:
        # give the engines a moment to go down before killing the hub
        time.sleep(0.25)
        self.session.send(self._query_socket, 'shutdown_request')
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        if msg['content']['status'] != 'ok':
            error = self._unwrap_exception(msg['content'])

    if error:
        raise error
957
955
958 #--------------------------------------------------------------------------
956 #--------------------------------------------------------------------------
959 # Execution related methods
957 # Execution related methods
960 #--------------------------------------------------------------------------
958 #--------------------------------------------------------------------------
961
959
962 def _maybe_raise(self, result):
960 def _maybe_raise(self, result):
963 """wrapper for maybe raising an exception if apply failed."""
961 """wrapper for maybe raising an exception if apply failed."""
964 if isinstance(result, error.RemoteError):
962 if isinstance(result, error.RemoteError):
965 raise result
963 raise result
966
964
967 return result
965 return result
968
966
969 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
967 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
970 ident=None):
968 ident=None):
971 """construct and send an apply message via a socket.
969 """construct and send an apply message via a socket.
972
970
973 This is the principal method with which all engine execution is performed by views.
971 This is the principal method with which all engine execution is performed by views.
974 """
972 """
975
973
976 assert not self._closed, "cannot use me anymore, I'm closed!"
974 assert not self._closed, "cannot use me anymore, I'm closed!"
977 # defaults:
975 # defaults:
978 args = args if args is not None else []
976 args = args if args is not None else []
979 kwargs = kwargs if kwargs is not None else {}
977 kwargs = kwargs if kwargs is not None else {}
980 subheader = subheader if subheader is not None else {}
978 subheader = subheader if subheader is not None else {}
981
979
982 # validate arguments
980 # validate arguments
983 if not callable(f):
981 if not callable(f):
984 raise TypeError("f must be callable, not %s"%type(f))
982 raise TypeError("f must be callable, not %s"%type(f))
985 if not isinstance(args, (tuple, list)):
983 if not isinstance(args, (tuple, list)):
986 raise TypeError("args must be tuple or list, not %s"%type(args))
984 raise TypeError("args must be tuple or list, not %s"%type(args))
987 if not isinstance(kwargs, dict):
985 if not isinstance(kwargs, dict):
988 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
986 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
989 if not isinstance(subheader, dict):
987 if not isinstance(subheader, dict):
990 raise TypeError("subheader must be dict, not %s"%type(subheader))
988 raise TypeError("subheader must be dict, not %s"%type(subheader))
991
989
992 bufs = util.pack_apply_message(f,args,kwargs)
990 bufs = util.pack_apply_message(f,args,kwargs)
993
991
994 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
992 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
995 subheader=subheader, track=track)
993 subheader=subheader, track=track)
996
994
997 msg_id = msg['header']['msg_id']
995 msg_id = msg['header']['msg_id']
998 self.outstanding.add(msg_id)
996 self.outstanding.add(msg_id)
999 if ident:
997 if ident:
1000 # possibly routed to a specific engine
998 # possibly routed to a specific engine
1001 if isinstance(ident, list):
999 if isinstance(ident, list):
1002 ident = ident[-1]
1000 ident = ident[-1]
1003 if ident in self._engines.values():
1001 if ident in self._engines.values():
1004 # save for later, in case of engine death
1002 # save for later, in case of engine death
1005 self._outstanding_dict[ident].add(msg_id)
1003 self._outstanding_dict[ident].add(msg_id)
1006 self.history.append(msg_id)
1004 self.history.append(msg_id)
1007 self.metadata[msg_id]['submitted'] = datetime.now()
1005 self.metadata[msg_id]['submitted'] = datetime.now()
1008
1006
1009 return msg
1007 return msg
1010
1008
1011 #--------------------------------------------------------------------------
1009 #--------------------------------------------------------------------------
1012 # construct a View object
1010 # construct a View object
1013 #--------------------------------------------------------------------------
1011 #--------------------------------------------------------------------------
1014
1012
1015 def load_balanced_view(self, targets=None):
1013 def load_balanced_view(self, targets=None):
1016 """construct a DirectView object.
1014 """construct a DirectView object.
1017
1015
1018 If no arguments are specified, create a LoadBalancedView
1016 If no arguments are specified, create a LoadBalancedView
1019 using all engines.
1017 using all engines.
1020
1018
1021 Parameters
1019 Parameters
1022 ----------
1020 ----------
1023
1021
1024 targets: list,slice,int,etc. [default: use all engines]
1022 targets: list,slice,int,etc. [default: use all engines]
1025 The subset of engines across which to load-balance
1023 The subset of engines across which to load-balance
1026 """
1024 """
1027 if targets == 'all':
1025 if targets == 'all':
1028 targets = None
1026 targets = None
1029 if targets is not None:
1027 if targets is not None:
1030 targets = self._build_targets(targets)[1]
1028 targets = self._build_targets(targets)[1]
1031 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1029 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1032
1030
1033 def direct_view(self, targets='all'):
1031 def direct_view(self, targets='all'):
1034 """construct a DirectView object.
1032 """construct a DirectView object.
1035
1033
1036 If no targets are specified, create a DirectView
1034 If no targets are specified, create a DirectView
1037 using all engines.
1035 using all engines.
1038
1036
1039 Parameters
1037 Parameters
1040 ----------
1038 ----------
1041
1039
1042 targets: list,slice,int,etc. [default: use all engines]
1040 targets: list,slice,int,etc. [default: use all engines]
1043 The engines to use for the View
1041 The engines to use for the View
1044 """
1042 """
1045 single = isinstance(targets, int)
1043 single = isinstance(targets, int)
1046 # allow 'all' to be lazily evaluated at each execution
1044 # allow 'all' to be lazily evaluated at each execution
1047 if targets != 'all':
1045 if targets != 'all':
1048 targets = self._build_targets(targets)[1]
1046 targets = self._build_targets(targets)[1]
1049 if single:
1047 if single:
1050 targets = targets[0]
1048 targets = targets[0]
1051 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1049 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1052
1050
1053 #--------------------------------------------------------------------------
1051 #--------------------------------------------------------------------------
1054 # Query methods
1052 # Query methods
1055 #--------------------------------------------------------------------------
1053 #--------------------------------------------------------------------------
1056
1054
1057 @spin_first
1055 @spin_first
1058 def get_result(self, indices_or_msg_ids=None, block=None):
1056 def get_result(self, indices_or_msg_ids=None, block=None):
1059 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1057 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1060
1058
1061 If the client already has the results, no request to the Hub will be made.
1059 If the client already has the results, no request to the Hub will be made.
1062
1060
1063 This is a convenient way to construct AsyncResult objects, which are wrappers
1061 This is a convenient way to construct AsyncResult objects, which are wrappers
1064 that include metadata about execution, and allow for awaiting results that
1062 that include metadata about execution, and allow for awaiting results that
1065 were not submitted by this Client.
1063 were not submitted by this Client.
1066
1064
1067 It can also be a convenient way to retrieve the metadata associated with
1065 It can also be a convenient way to retrieve the metadata associated with
1068 blocking execution, since it always retrieves
1066 blocking execution, since it always retrieves
1069
1067
1070 Examples
1068 Examples
1071 --------
1069 --------
1072 ::
1070 ::
1073
1071
1074 In [10]: r = client.apply()
1072 In [10]: r = client.apply()
1075
1073
1076 Parameters
1074 Parameters
1077 ----------
1075 ----------
1078
1076
1079 indices_or_msg_ids : integer history index, str msg_id, or list of either
1077 indices_or_msg_ids : integer history index, str msg_id, or list of either
1080 The indices or msg_ids of indices to be retrieved
1078 The indices or msg_ids of indices to be retrieved
1081
1079
1082 block : bool
1080 block : bool
1083 Whether to wait for the result to be done
1081 Whether to wait for the result to be done
1084
1082
1085 Returns
1083 Returns
1086 -------
1084 -------
1087
1085
1088 AsyncResult
1086 AsyncResult
1089 A single AsyncResult object will always be returned.
1087 A single AsyncResult object will always be returned.
1090
1088
1091 AsyncHubResult
1089 AsyncHubResult
1092 A subclass of AsyncResult that retrieves results from the Hub
1090 A subclass of AsyncResult that retrieves results from the Hub
1093
1091
1094 """
1092 """
1095 block = self.block if block is None else block
1093 block = self.block if block is None else block
1096 if indices_or_msg_ids is None:
1094 if indices_or_msg_ids is None:
1097 indices_or_msg_ids = -1
1095 indices_or_msg_ids = -1
1098
1096
1099 if not isinstance(indices_or_msg_ids, (list,tuple)):
1097 if not isinstance(indices_or_msg_ids, (list,tuple)):
1100 indices_or_msg_ids = [indices_or_msg_ids]
1098 indices_or_msg_ids = [indices_or_msg_ids]
1101
1099
1102 theids = []
1100 theids = []
1103 for id in indices_or_msg_ids:
1101 for id in indices_or_msg_ids:
1104 if isinstance(id, int):
1102 if isinstance(id, int):
1105 id = self.history[id]
1103 id = self.history[id]
1106 if not isinstance(id, basestring):
1104 if not isinstance(id, basestring):
1107 raise TypeError("indices must be str or int, not %r"%id)
1105 raise TypeError("indices must be str or int, not %r"%id)
1108 theids.append(id)
1106 theids.append(id)
1109
1107
1110 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1108 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1111 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1109 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1112
1110
1113 if remote_ids:
1111 if remote_ids:
1114 ar = AsyncHubResult(self, msg_ids=theids)
1112 ar = AsyncHubResult(self, msg_ids=theids)
1115 else:
1113 else:
1116 ar = AsyncResult(self, msg_ids=theids)
1114 ar = AsyncResult(self, msg_ids=theids)
1117
1115
1118 if block:
1116 if block:
1119 ar.wait()
1117 ar.wait()
1120
1118
1121 return ar
1119 return ar
1122
1120
1123 @spin_first
1121 @spin_first
1124 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1122 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1125 """Resubmit one or more tasks.
1123 """Resubmit one or more tasks.
1126
1124
1127 in-flight tasks may not be resubmitted.
1125 in-flight tasks may not be resubmitted.
1128
1126
1129 Parameters
1127 Parameters
1130 ----------
1128 ----------
1131
1129
1132 indices_or_msg_ids : integer history index, str msg_id, or list of either
1130 indices_or_msg_ids : integer history index, str msg_id, or list of either
1133 The indices or msg_ids of indices to be retrieved
1131 The indices or msg_ids of indices to be retrieved
1134
1132
1135 block : bool
1133 block : bool
1136 Whether to wait for the result to be done
1134 Whether to wait for the result to be done
1137
1135
1138 Returns
1136 Returns
1139 -------
1137 -------
1140
1138
1141 AsyncHubResult
1139 AsyncHubResult
1142 A subclass of AsyncResult that retrieves results from the Hub
1140 A subclass of AsyncResult that retrieves results from the Hub
1143
1141
1144 """
1142 """
1145 block = self.block if block is None else block
1143 block = self.block if block is None else block
1146 if indices_or_msg_ids is None:
1144 if indices_or_msg_ids is None:
1147 indices_or_msg_ids = -1
1145 indices_or_msg_ids = -1
1148
1146
1149 if not isinstance(indices_or_msg_ids, (list,tuple)):
1147 if not isinstance(indices_or_msg_ids, (list,tuple)):
1150 indices_or_msg_ids = [indices_or_msg_ids]
1148 indices_or_msg_ids = [indices_or_msg_ids]
1151
1149
1152 theids = []
1150 theids = []
1153 for id in indices_or_msg_ids:
1151 for id in indices_or_msg_ids:
1154 if isinstance(id, int):
1152 if isinstance(id, int):
1155 id = self.history[id]
1153 id = self.history[id]
1156 if not isinstance(id, basestring):
1154 if not isinstance(id, basestring):
1157 raise TypeError("indices must be str or int, not %r"%id)
1155 raise TypeError("indices must be str or int, not %r"%id)
1158 theids.append(id)
1156 theids.append(id)
1159
1157
1160 for msg_id in theids:
1158 for msg_id in theids:
1161 self.outstanding.discard(msg_id)
1159 self.outstanding.discard(msg_id)
1162 if msg_id in self.history:
1160 if msg_id in self.history:
1163 self.history.remove(msg_id)
1161 self.history.remove(msg_id)
1164 self.results.pop(msg_id, None)
1162 self.results.pop(msg_id, None)
1165 self.metadata.pop(msg_id, None)
1163 self.metadata.pop(msg_id, None)
1166 content = dict(msg_ids = theids)
1164 content = dict(msg_ids = theids)
1167
1165
1168 self.session.send(self._query_socket, 'resubmit_request', content)
1166 self.session.send(self._query_socket, 'resubmit_request', content)
1169
1167
1170 zmq.select([self._query_socket], [], [])
1168 zmq.select([self._query_socket], [], [])
1171 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1169 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1172 if self.debug:
1170 if self.debug:
1173 pprint(msg)
1171 pprint(msg)
1174 content = msg['content']
1172 content = msg['content']
1175 if content['status'] != 'ok':
1173 if content['status'] != 'ok':
1176 raise self._unwrap_exception(content)
1174 raise self._unwrap_exception(content)
1177
1175
1178 ar = AsyncHubResult(self, msg_ids=theids)
1176 ar = AsyncHubResult(self, msg_ids=theids)
1179
1177
1180 if block:
1178 if block:
1181 ar.wait()
1179 ar.wait()
1182
1180
1183 return ar
1181 return ar
1184
1182
1185 @spin_first
1183 @spin_first
1186 def result_status(self, msg_ids, status_only=True):
1184 def result_status(self, msg_ids, status_only=True):
1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1185 """Check on the status of the result(s) of the apply request with `msg_ids`.
1188
1186
1189 If status_only is False, then the actual results will be retrieved, else
1187 If status_only is False, then the actual results will be retrieved, else
1190 only the status of the results will be checked.
1188 only the status of the results will be checked.
1191
1189
1192 Parameters
1190 Parameters
1193 ----------
1191 ----------
1194
1192
1195 msg_ids : list of msg_ids
1193 msg_ids : list of msg_ids
1196 if int:
1194 if int:
1197 Passed as index to self.history for convenience.
1195 Passed as index to self.history for convenience.
1198 status_only : bool (default: True)
1196 status_only : bool (default: True)
1199 if False:
1197 if False:
1200 Retrieve the actual results of completed tasks.
1198 Retrieve the actual results of completed tasks.
1201
1199
1202 Returns
1200 Returns
1203 -------
1201 -------
1204
1202
1205 results : dict
1203 results : dict
1206 There will always be the keys 'pending' and 'completed', which will
1204 There will always be the keys 'pending' and 'completed', which will
1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1205 be lists of msg_ids that are incomplete or complete. If `status_only`
1208 is False, then completed results will be keyed by their `msg_id`.
1206 is False, then completed results will be keyed by their `msg_id`.
1209 """
1207 """
1210 if not isinstance(msg_ids, (list,tuple)):
1208 if not isinstance(msg_ids, (list,tuple)):
1211 msg_ids = [msg_ids]
1209 msg_ids = [msg_ids]
1212
1210
1213 theids = []
1211 theids = []
1214 for msg_id in msg_ids:
1212 for msg_id in msg_ids:
1215 if isinstance(msg_id, int):
1213 if isinstance(msg_id, int):
1216 msg_id = self.history[msg_id]
1214 msg_id = self.history[msg_id]
1217 if not isinstance(msg_id, basestring):
1215 if not isinstance(msg_id, basestring):
1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1216 raise TypeError("msg_ids must be str, not %r"%msg_id)
1219 theids.append(msg_id)
1217 theids.append(msg_id)
1220
1218
1221 completed = []
1219 completed = []
1222 local_results = {}
1220 local_results = {}
1223
1221
1224 # comment this block out to temporarily disable local shortcut:
1222 # comment this block out to temporarily disable local shortcut:
1225 for msg_id in theids:
1223 for msg_id in theids:
1226 if msg_id in self.results:
1224 if msg_id in self.results:
1227 completed.append(msg_id)
1225 completed.append(msg_id)
1228 local_results[msg_id] = self.results[msg_id]
1226 local_results[msg_id] = self.results[msg_id]
1229 theids.remove(msg_id)
1227 theids.remove(msg_id)
1230
1228
1231 if theids: # some not locally cached
1229 if theids: # some not locally cached
1232 content = dict(msg_ids=theids, status_only=status_only)
1230 content = dict(msg_ids=theids, status_only=status_only)
1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1231 msg = self.session.send(self._query_socket, "result_request", content=content)
1234 zmq.select([self._query_socket], [], [])
1232 zmq.select([self._query_socket], [], [])
1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1233 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1236 if self.debug:
1234 if self.debug:
1237 pprint(msg)
1235 pprint(msg)
1238 content = msg['content']
1236 content = msg['content']
1239 if content['status'] != 'ok':
1237 if content['status'] != 'ok':
1240 raise self._unwrap_exception(content)
1238 raise self._unwrap_exception(content)
1241 buffers = msg['buffers']
1239 buffers = msg['buffers']
1242 else:
1240 else:
1243 content = dict(completed=[],pending=[])
1241 content = dict(completed=[],pending=[])
1244
1242
1245 content['completed'].extend(completed)
1243 content['completed'].extend(completed)
1246
1244
1247 if status_only:
1245 if status_only:
1248 return content
1246 return content
1249
1247
1250 failures = []
1248 failures = []
1251 # load cached results into result:
1249 # load cached results into result:
1252 content.update(local_results)
1250 content.update(local_results)
1253
1251
1254 # update cache with results:
1252 # update cache with results:
1255 for msg_id in sorted(theids):
1253 for msg_id in sorted(theids):
1256 if msg_id in content['completed']:
1254 if msg_id in content['completed']:
1257 rec = content[msg_id]
1255 rec = content[msg_id]
1258 parent = rec['header']
1256 parent = rec['header']
1259 header = rec['result_header']
1257 header = rec['result_header']
1260 rcontent = rec['result_content']
1258 rcontent = rec['result_content']
1261 iodict = rec['io']
1259 iodict = rec['io']
1262 if isinstance(rcontent, str):
1260 if isinstance(rcontent, str):
1263 rcontent = self.session.unpack(rcontent)
1261 rcontent = self.session.unpack(rcontent)
1264
1262
1265 md = self.metadata[msg_id]
1263 md = self.metadata[msg_id]
1266 md.update(self._extract_metadata(header, parent, rcontent))
1264 md.update(self._extract_metadata(header, parent, rcontent))
1267 md.update(iodict)
1265 md.update(iodict)
1268
1266
1269 if rcontent['status'] == 'ok':
1267 if rcontent['status'] == 'ok':
1270 res,buffers = util.unserialize_object(buffers)
1268 res,buffers = util.unserialize_object(buffers)
1271 else:
1269 else:
1272 print rcontent
1270 print rcontent
1273 res = self._unwrap_exception(rcontent)
1271 res = self._unwrap_exception(rcontent)
1274 failures.append(res)
1272 failures.append(res)
1275
1273
1276 self.results[msg_id] = res
1274 self.results[msg_id] = res
1277 content[msg_id] = res
1275 content[msg_id] = res
1278
1276
1279 if len(theids) == 1 and failures:
1277 if len(theids) == 1 and failures:
1280 raise failures[0]
1278 raise failures[0]
1281
1279
1282 error.collect_exceptions(failures, "result_status")
1280 error.collect_exceptions(failures, "result_status")
1283 return content
1281 return content
1284
1282
1285 @spin_first
1283 @spin_first
1286 def queue_status(self, targets='all', verbose=False):
1284 def queue_status(self, targets='all', verbose=False):
1287 """Fetch the status of engine queues.
1285 """Fetch the status of engine queues.
1288
1286
1289 Parameters
1287 Parameters
1290 ----------
1288 ----------
1291
1289
1292 targets : int/str/list of ints/strs
1290 targets : int/str/list of ints/strs
1293 the engines whose states are to be queried.
1291 the engines whose states are to be queried.
1294 default : all
1292 default : all
1295 verbose : bool
1293 verbose : bool
1296 Whether to return lengths only, or lists of ids for each element
1294 Whether to return lengths only, or lists of ids for each element
1297 """
1295 """
1298 engine_ids = self._build_targets(targets)[1]
1296 engine_ids = self._build_targets(targets)[1]
1299 content = dict(targets=engine_ids, verbose=verbose)
1297 content = dict(targets=engine_ids, verbose=verbose)
1300 self.session.send(self._query_socket, "queue_request", content=content)
1298 self.session.send(self._query_socket, "queue_request", content=content)
1301 idents,msg = self.session.recv(self._query_socket, 0)
1299 idents,msg = self.session.recv(self._query_socket, 0)
1302 if self.debug:
1300 if self.debug:
1303 pprint(msg)
1301 pprint(msg)
1304 content = msg['content']
1302 content = msg['content']
1305 status = content.pop('status')
1303 status = content.pop('status')
1306 if status != 'ok':
1304 if status != 'ok':
1307 raise self._unwrap_exception(content)
1305 raise self._unwrap_exception(content)
1308 content = rekey(content)
1306 content = rekey(content)
1309 if isinstance(targets, int):
1307 if isinstance(targets, int):
1310 return content[targets]
1308 return content[targets]
1311 else:
1309 else:
1312 return content
1310 return content
1313
1311
1314 @spin_first
1312 @spin_first
1315 def purge_results(self, jobs=[], targets=[]):
1313 def purge_results(self, jobs=[], targets=[]):
1316 """Tell the Hub to forget results.
1314 """Tell the Hub to forget results.
1317
1315
1318 Individual results can be purged by msg_id, or the entire
1316 Individual results can be purged by msg_id, or the entire
1319 history of specific targets can be purged.
1317 history of specific targets can be purged.
1320
1318
1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1319 Use `purge_results('all')` to scrub everything from the Hub's db.
1322
1320
1323 Parameters
1321 Parameters
1324 ----------
1322 ----------
1325
1323
1326 jobs : str or list of str or AsyncResult objects
1324 jobs : str or list of str or AsyncResult objects
1327 the msg_ids whose results should be forgotten.
1325 the msg_ids whose results should be forgotten.
1328 targets : int/str/list of ints/strs
1326 targets : int/str/list of ints/strs
1329 The targets, by int_id, whose entire history is to be purged.
1327 The targets, by int_id, whose entire history is to be purged.
1330
1328
1331 default : None
1329 default : None
1332 """
1330 """
1333 if not targets and not jobs:
1331 if not targets and not jobs:
1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1332 raise ValueError("Must specify at least one of `targets` and `jobs`")
1335 if targets:
1333 if targets:
1336 targets = self._build_targets(targets)[1]
1334 targets = self._build_targets(targets)[1]
1337
1335
1338 # construct msg_ids from jobs
1336 # construct msg_ids from jobs
1339 if jobs == 'all':
1337 if jobs == 'all':
1340 msg_ids = jobs
1338 msg_ids = jobs
1341 else:
1339 else:
1342 msg_ids = []
1340 msg_ids = []
1343 if isinstance(jobs, (basestring,AsyncResult)):
1341 if isinstance(jobs, (basestring,AsyncResult)):
1344 jobs = [jobs]
1342 jobs = [jobs]
1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1343 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1346 if bad_ids:
1344 if bad_ids:
1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1345 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1348 for j in jobs:
1346 for j in jobs:
1349 if isinstance(j, AsyncResult):
1347 if isinstance(j, AsyncResult):
1350 msg_ids.extend(j.msg_ids)
1348 msg_ids.extend(j.msg_ids)
1351 else:
1349 else:
1352 msg_ids.append(j)
1350 msg_ids.append(j)
1353
1351
1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1352 content = dict(engine_ids=targets, msg_ids=msg_ids)
1355 self.session.send(self._query_socket, "purge_request", content=content)
1353 self.session.send(self._query_socket, "purge_request", content=content)
1356 idents, msg = self.session.recv(self._query_socket, 0)
1354 idents, msg = self.session.recv(self._query_socket, 0)
1357 if self.debug:
1355 if self.debug:
1358 pprint(msg)
1356 pprint(msg)
1359 content = msg['content']
1357 content = msg['content']
1360 if content['status'] != 'ok':
1358 if content['status'] != 'ok':
1361 raise self._unwrap_exception(content)
1359 raise self._unwrap_exception(content)
1362
1360
1363 @spin_first
1361 @spin_first
1364 def hub_history(self):
1362 def hub_history(self):
1365 """Get the Hub's history
1363 """Get the Hub's history
1366
1364
1367 Just like the Client, the Hub has a history, which is a list of msg_ids.
1365 Just like the Client, the Hub has a history, which is a list of msg_ids.
1368 This will contain the history of all clients, and, depending on configuration,
1366 This will contain the history of all clients, and, depending on configuration,
1369 may contain history across multiple cluster sessions.
1367 may contain history across multiple cluster sessions.
1370
1368
1371 Any msg_id returned here is a valid argument to `get_result`.
1369 Any msg_id returned here is a valid argument to `get_result`.
1372
1370
1373 Returns
1371 Returns
1374 -------
1372 -------
1375
1373
1376 msg_ids : list of strs
1374 msg_ids : list of strs
1377 list of all msg_ids, ordered by task submission time.
1375 list of all msg_ids, ordered by task submission time.
1378 """
1376 """
1379
1377
1380 self.session.send(self._query_socket, "history_request", content={})
1378 self.session.send(self._query_socket, "history_request", content={})
1381 idents, msg = self.session.recv(self._query_socket, 0)
1379 idents, msg = self.session.recv(self._query_socket, 0)
1382
1380
1383 if self.debug:
1381 if self.debug:
1384 pprint(msg)
1382 pprint(msg)
1385 content = msg['content']
1383 content = msg['content']
1386 if content['status'] != 'ok':
1384 if content['status'] != 'ok':
1387 raise self._unwrap_exception(content)
1385 raise self._unwrap_exception(content)
1388 else:
1386 else:
1389 return content['history']
1387 return content['history']
1390
1388
1391 @spin_first
1389 @spin_first
1392 def db_query(self, query, keys=None):
1390 def db_query(self, query, keys=None):
1393 """Query the Hub's TaskRecord database
1391 """Query the Hub's TaskRecord database
1394
1392
1395 This will return a list of task record dicts that match `query`
1393 This will return a list of task record dicts that match `query`
1396
1394
1397 Parameters
1395 Parameters
1398 ----------
1396 ----------
1399
1397
1400 query : mongodb query dict
1398 query : mongodb query dict
1401 The search dict. See mongodb query docs for details.
1399 The search dict. See mongodb query docs for details.
1402 keys : list of strs [optional]
1400 keys : list of strs [optional]
1403 The subset of keys to be returned. The default is to fetch everything but buffers.
1401 The subset of keys to be returned. The default is to fetch everything but buffers.
1404 'msg_id' will *always* be included.
1402 'msg_id' will *always* be included.
1405 """
1403 """
1406 if isinstance(keys, basestring):
1404 if isinstance(keys, basestring):
1407 keys = [keys]
1405 keys = [keys]
1408 content = dict(query=query, keys=keys)
1406 content = dict(query=query, keys=keys)
1409 self.session.send(self._query_socket, "db_request", content=content)
1407 self.session.send(self._query_socket, "db_request", content=content)
1410 idents, msg = self.session.recv(self._query_socket, 0)
1408 idents, msg = self.session.recv(self._query_socket, 0)
1411 if self.debug:
1409 if self.debug:
1412 pprint(msg)
1410 pprint(msg)
1413 content = msg['content']
1411 content = msg['content']
1414 if content['status'] != 'ok':
1412 if content['status'] != 'ok':
1415 raise self._unwrap_exception(content)
1413 raise self._unwrap_exception(content)
1416
1414
1417 records = content['records']
1415 records = content['records']
1418
1416
1419 buffer_lens = content['buffer_lens']
1417 buffer_lens = content['buffer_lens']
1420 result_buffer_lens = content['result_buffer_lens']
1418 result_buffer_lens = content['result_buffer_lens']
1421 buffers = msg['buffers']
1419 buffers = msg['buffers']
1422 has_bufs = buffer_lens is not None
1420 has_bufs = buffer_lens is not None
1423 has_rbufs = result_buffer_lens is not None
1421 has_rbufs = result_buffer_lens is not None
1424 for i,rec in enumerate(records):
1422 for i,rec in enumerate(records):
1425 # relink buffers
1423 # relink buffers
1426 if has_bufs:
1424 if has_bufs:
1427 blen = buffer_lens[i]
1425 blen = buffer_lens[i]
1428 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1426 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1429 if has_rbufs:
1427 if has_rbufs:
1430 blen = result_buffer_lens[i]
1428 blen = result_buffer_lens[i]
1431 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1429 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1432
1430
1433 return records
1431 return records
1434
1432
1435 __all__ = [ 'Client' ]
1433 __all__ = [ 'Client' ]
@@ -1,463 +1,472 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 # IPython imports
42 # IPython imports
43 from IPython.config.application import Application
43 from IPython.config.application import Application
44 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
45 from IPython.utils.newserialized import serialize, unserialize
45 from IPython.utils.newserialized import serialize, unserialize
46 from IPython.zmq.log import EnginePUBHandler
46 from IPython.zmq.log import EnginePUBHandler
47
47
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49 # Classes
49 # Classes
50 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
51
51
class Namespace(dict):
    """Subclass of dict for attribute access to keys.

    ``ns.foo`` reads ``ns['foo']``; ``ns.foo = x`` writes ``ns['foo'] = x``.
    Missing attributes raise NameError; names that would shadow dict
    methods are rejected with KeyError.
    """

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # direct membership test instead of scanning iterkeys(); identical
        # behavior, and avoids recursing through __getattr__ on py3
        if key in self:
            return self[key]
        else:
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        # refuse to shadow dict API names (keys, items, update, ...)
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
67
67
68
68
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Lookups that miss the forward (key -> value) mapping fall back to the
    reverse (value -> key) mapping.  Only the methods below maintain the
    reverse map; values must therefore be hashable and unique.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        # items() instead of py2-only iteritems(); same iteration
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            # fall back to reverse (value -> key) lookup
            return self._reverse[key]

    def __setitem__(self, key, value):
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        # forward pop only; also drops the reverse entry
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
100
100
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102 # Functions
102 # Functions
103 #-----------------------------------------------------------------------------
103 #-----------------------------------------------------------------------------
104
104
def asbytes(s):
    """ensure that an object is ascii bytes"""
    # unicode gets encoded; bytes (and anything else) pass through untouched
    if isinstance(s, unicode):
        return s.encode('ascii')
    return s
110
110
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    # must look like <proto>://<addr> with a zmq-supported transport
    if '://' not in url:
        return False
    proto, addr = url.split('://', 1)
    return proto.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
119
def validate_url(url):
    """validate a url for zeromq

    Returns True for a valid url; raises TypeError for non-strings and
    AssertionError for malformed urls.  Only tcp urls are fully checked;
    other transports (pgm/epgm/ipc/inproc) are accepted after the
    protocol check.
    """
    if not isinstance(url, basestring):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            int(s_port)
        except ValueError:
            # BUGFIX: must report s_port here -- the original referenced
            # `port`, which is unbound when int() fails, raising NameError
            # instead of the intended AssertionError
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
142
151
143
152
def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    # a bare string is a single url
    if isinstance(container, basestring):
        return validate_url(container)
    # for dicts, the urls live in the values
    if isinstance(container, dict):
        container = container.itervalues()
    for element in container:
        validate_url_container(element)
154
163
155
164
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    pieces = url.split('://')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    proto, addr = pieces
    ip_port = addr.split(':')
    assert len(ip_port) == 2, 'Invalid url: %r'%url
    ip, s_port = ip_port
    # port is returned as a string, not converted to int
    return proto, ip, s_port
165
174
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    # specific addresses are already connectable; pass through
    if ip not in ('0.0.0.0', '*'):
        return ip
    try:
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
    except (socket.gaierror, IndexError):
        # couldn't identify this machine, assume localhost
        external_ips = []
    if location is None or location in external_ips or not external_ips:
        # If location is unspecified or cannot be determined, assume local
        return '127.0.0.1'
    if location:
        return location
    # falsy-but-not-None location: leave the wildcard ip untouched
    return ip
181
190
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc. -- pass through unchanged
        return url
    return "%s://%s:%s"%(proto, disambiguate_ip_address(ip, location), port)
196
205
def _extract_buffer(s, threshold, databuffers):
    """Move binary/large payloads out of a serialized wrapper into databuffers.

    buffers and ndarrays are always extracted; other data only when it is
    larger than `threshold`.  Mutates `s` (sets s.data = None) and appends
    to `databuffers` in place.
    """
    if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
        databuffers.append(s.getData())
        s.data = None


def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        clist = canSequence(obj)
        slist = map(serialize, clist)
        for s in slist:
            _extract_buffer(s, threshold, databuffers)
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted keys for a deterministic buffer order
        for k in sorted(obj.iterkeys()):
            s = serialize(can(obj[k]))
            _extract_buffer(s, threshold, databuffers)
            sobj[k] = s
        return pickle.dumps(sobj,-1), databuffers
    else:
        s = serialize(can(obj))
        _extract_buffer(s, threshold, databuffers)
        return pickle.dumps(s,-1), databuffers
239
248
240
249
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers."""
    remaining = list(bufs)
    sobj = pickle.loads(remaining.pop(0))
    if isinstance(sobj, (list, tuple)):
        # reattach stripped payloads in order, then rebuild the sequence
        for s in sobj:
            if s.data is None:
                s.data = remaining.pop(0)
        return uncanSequence(map(unserialize, sobj)), remaining
    if isinstance(sobj, dict):
        restored = {}
        # sorted keys must match the order used by serialize_object
        for key in sorted(sobj.iterkeys()):
            s = sobj[key]
            if s.data is None:
                s.data = remaining.pop(0)
            restored[key] = uncan(unserialize(s))
        return restored, remaining
    # scalar case
    if sobj.data is None:
        sobj.data = remaining.pop(0)
    return uncan(unserialize(sobj)), remaining
262
271
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    msg = [pickle.dumps(can(f),-1)]
    databuffers = []    # large payloads collected here, appended last
    for part in (args, kwargs):
        serialized, bufs = serialize_object(part, threshold)
        msg.append(serialized)
        databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg
277
286
def _relink_data(sa, bufs, copy):
    """Reattach the next raw buffer to serialized wrapper `sa` if its data
    was stripped on the sending side.  Mutates `sa` and pops from `bufs`."""
    if sa.data is not None:
        return
    m = bufs.pop(0)
    if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
        # always use a buffer, until memoryviews get sorted out
        sa.data = buffer(m)
    elif copy:
        sa.data = m
    else:
        # zero-copy path: keep the zmq message's bytes
        sa.data = m.bytes


def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # first three frames (f, args, kwargs) must be real bytes to unpickle
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    for sa in sargs:
        _relink_data(sa, bufs, copy)
    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted keys must match the order used by pack_apply_message
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        _relink_data(sa, bufs, copy)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
331
340
332 #--------------------------------------------------------------------------
341 #--------------------------------------------------------------------------
333 # helpers for implementing old MEC API via view.apply
342 # helpers for implementing old MEC API via view.apply
334 #--------------------------------------------------------------------------
343 #--------------------------------------------------------------------------
335
344
def interactive(f):
    """decorator for making functions appear as interactively defined.

    Rebinding __module__ to '__main__' links the function to the user_ns
    as its globals() instead of this module's globals().
    """
    f.__module__ = '__main__'
    return f
343
352
@interactive
def _push(ns):
    """helper method for implementing `client.push` via `client.apply`"""
    # thanks to @interactive, globals() here is the engine's user namespace
    user_ns = globals()
    user_ns.update(ns)
348
357
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`

    Raises NameError if any requested name is missing from the user
    namespace; returns a list for a collection of keys, a single value
    otherwise.
    """
    user_ns = globals()
    if isinstance(keys, (list,tuple, set)):
        for key in keys:
            # `in` instead of deprecated dict.has_key(); same behavior
            if key not in user_ns:
                raise NameError("name '%s' is not defined"%key)
        return map(user_ns.get, keys)
    else:
        if keys not in user_ns:
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)
362
371
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    # tuple-call form of exec: valid in python 2 and parseable by python 3;
    # runs the code in the (engine user) namespace via @interactive
    exec(code, globals())
367
376
368 #--------------------------------------------------------------------------
377 #--------------------------------------------------------------------------
369 # extra process management utilities
378 # extra process management utilities
370 #--------------------------------------------------------------------------
379 #--------------------------------------------------------------------------
371
380
# ports already handed out during this process's lifetime
_random_ports = set()

def select_random_ports(n):
    """Selects and return n random ports that are available."""
    sockets = []
    for _ in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        # re-roll until we land on a port not handed out before;
        # keep sockets open so the OS won't reuse a port within this call
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        sockets.append(sock)
    ports = []
    for sock in sockets:
        port = sock.getsockname()[1]
        sock.close()
        ports.append(port)
        _random_ports.add(port)
    return ports
391
400
def signal_children(children):
    """Relay interupt/term signals to children, for more solid process cleanup."""
    def _relay(sig, frame):
        log = Application.instance().log
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()
        # exit status 0 for a plain Ctrl-C, nonzero for ABRT/TERM
        sys.exit(sig != SIGINT)

    for signum in (SIGINT, SIGABRT, SIGTERM):
        signal(signum, _relay)
404
413
def generate_exec_key(keyfile):
    """Generate a fresh random auth key and write it to `keyfile`.

    The file is created with user-only RW permissions (0600) from the
    start, instead of writing first and chmod-ing afterwards -- the old
    order left a window where the key was world-readable.  Permission
    bits have no effect on Windows, as before.
    """
    import uuid
    newkey = str(uuid.uuid4())
    fd = os.open(keyfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        f.write(newkey+'\n')
414
423
415
424
def integer_loglevel(loglevel):
    """coerce a loglevel given as an int, numeric string, or level name
    (e.g. 'DEBUG') into an integer level"""
    try:
        return int(loglevel)
    except ValueError:
        pass
    # not numeric: look the name up on the logging module
    if isinstance(loglevel, str):
        loglevel = getattr(logging, loglevel)
    return loglevel
423
432
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler publishing to `iface` to the named logger."""
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # already wired up; don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(loglevel)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(loglevel)
437
446
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for `engine` to the root logger."""
    logger = logging.getLogger()
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        # NOTE(review): this early exit returns None while the normal path
        # returns the logger -- confirm no caller relies on the return value
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
451
460
def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a StreamHandler to the named logger and return the logger.

    Idempotent: calling it again will not add a duplicate StreamHandler.
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        # already configured; still return the logger (previously this
        # path returned None, inconsistent with the normal path below)
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
463
472
General Comments 0
You need to be logged in to leave comments. Login now