##// END OF EJS Templates
move rekey to jsonutil from parallel.util...
MinRK -
Show More
@@ -1,1373 +1,1374
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import time
20 import time
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23 from getpass import getpass
23 from getpass import getpass
24 from pprint import pprint
24 from pprint import pprint
25
25
26 pjoin = os.path.join
26 pjoin = os.path.join
27
27
28 import zmq
28 import zmq
29 # from zmq.eventloop import ioloop, zmqstream
29 # from zmq.eventloop import ioloop, zmqstream
30
30
31 from IPython.utils.jsonutil import rekey
31 from IPython.utils.path import get_ipython_dir
32 from IPython.utils.path import get_ipython_dir
32 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
33 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
33 Dict, List, Bool, Set)
34 Dict, List, Bool, Set)
34 from IPython.external.decorator import decorator
35 from IPython.external.decorator import decorator
35 from IPython.external.ssh import tunnel
36 from IPython.external.ssh import tunnel
36
37
37 from IPython.parallel import error
38 from IPython.parallel import error
38 from IPython.parallel import util
39 from IPython.parallel import util
39
40
40 from IPython.zmq.session import Session, Message
41 from IPython.zmq.session import Session, Message
41
42
42 from .asyncresult import AsyncResult, AsyncHubResult
43 from .asyncresult import AsyncResult, AsyncHubResult
43 from IPython.core.profiledir import ProfileDir, ProfileDirError
44 from IPython.core.profiledir import ProfileDir, ProfileDirError
44 from .view import DirectView, LoadBalancedView
45 from .view import DirectView, LoadBalancedView
45
46
46 #--------------------------------------------------------------------------
47 #--------------------------------------------------------------------------
47 # Decorators for Client methods
48 # Decorators for Client methods
48 #--------------------------------------------------------------------------
49 #--------------------------------------------------------------------------
49
50
50 @decorator
51 @decorator
51 def spin_first(f, self, *args, **kwargs):
52 def spin_first(f, self, *args, **kwargs):
52 """Call spin() to sync state prior to calling the method."""
53 """Call spin() to sync state prior to calling the method."""
53 self.spin()
54 self.spin()
54 return f(self, *args, **kwargs)
55 return f(self, *args, **kwargs)
55
56
56
57
57 #--------------------------------------------------------------------------
58 #--------------------------------------------------------------------------
58 # Classes
59 # Classes
59 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
60
61
61 class Metadata(dict):
62 class Metadata(dict):
62 """Subclass of dict for initializing metadata values.
63 """Subclass of dict for initializing metadata values.
63
64
64 Attribute access works on keys.
65 Attribute access works on keys.
65
66
66 These objects have a strict set of keys - errors will raise if you try
67 These objects have a strict set of keys - errors will raise if you try
67 to add new keys.
68 to add new keys.
68 """
69 """
69 def __init__(self, *args, **kwargs):
70 def __init__(self, *args, **kwargs):
70 dict.__init__(self)
71 dict.__init__(self)
71 md = {'msg_id' : None,
72 md = {'msg_id' : None,
72 'submitted' : None,
73 'submitted' : None,
73 'started' : None,
74 'started' : None,
74 'completed' : None,
75 'completed' : None,
75 'received' : None,
76 'received' : None,
76 'engine_uuid' : None,
77 'engine_uuid' : None,
77 'engine_id' : None,
78 'engine_id' : None,
78 'follow' : None,
79 'follow' : None,
79 'after' : None,
80 'after' : None,
80 'status' : None,
81 'status' : None,
81
82
82 'pyin' : None,
83 'pyin' : None,
83 'pyout' : None,
84 'pyout' : None,
84 'pyerr' : None,
85 'pyerr' : None,
85 'stdout' : '',
86 'stdout' : '',
86 'stderr' : '',
87 'stderr' : '',
87 }
88 }
88 self.update(md)
89 self.update(md)
89 self.update(dict(*args, **kwargs))
90 self.update(dict(*args, **kwargs))
90
91
91 def __getattr__(self, key):
92 def __getattr__(self, key):
92 """getattr aliased to getitem"""
93 """getattr aliased to getitem"""
93 if key in self.iterkeys():
94 if key in self.iterkeys():
94 return self[key]
95 return self[key]
95 else:
96 else:
96 raise AttributeError(key)
97 raise AttributeError(key)
97
98
98 def __setattr__(self, key, value):
99 def __setattr__(self, key, value):
99 """setattr aliased to setitem, with strict"""
100 """setattr aliased to setitem, with strict"""
100 if key in self.iterkeys():
101 if key in self.iterkeys():
101 self[key] = value
102 self[key] = value
102 else:
103 else:
103 raise AttributeError(key)
104 raise AttributeError(key)
104
105
105 def __setitem__(self, key, value):
106 def __setitem__(self, key, value):
106 """strict static key enforcement"""
107 """strict static key enforcement"""
107 if key in self.iterkeys():
108 if key in self.iterkeys():
108 dict.__setitem__(self, key, value)
109 dict.__setitem__(self, key, value)
109 else:
110 else:
110 raise KeyError(key)
111 raise KeyError(key)
111
112
112
113
113 class Client(HasTraits):
114 class Client(HasTraits):
114 """A semi-synchronous client to the IPython ZMQ cluster
115 """A semi-synchronous client to the IPython ZMQ cluster
115
116
116 Parameters
117 Parameters
117 ----------
118 ----------
118
119
119 url_or_file : bytes; zmq url or path to ipcontroller-client.json
120 url_or_file : bytes; zmq url or path to ipcontroller-client.json
120 Connection information for the Hub's registration. If a json connector
121 Connection information for the Hub's registration. If a json connector
121 file is given, then likely no further configuration is necessary.
122 file is given, then likely no further configuration is necessary.
122 [Default: use profile]
123 [Default: use profile]
123 profile : bytes
124 profile : bytes
124 The name of the Cluster profile to be used to find connector information.
125 The name of the Cluster profile to be used to find connector information.
125 [Default: 'default']
126 [Default: 'default']
126 context : zmq.Context
127 context : zmq.Context
127 Pass an existing zmq.Context instance, otherwise the client will create its own.
128 Pass an existing zmq.Context instance, otherwise the client will create its own.
128 debug : bool
129 debug : bool
129 flag for lots of message printing for debug purposes
130 flag for lots of message printing for debug purposes
130 timeout : int/float
131 timeout : int/float
131 time (in seconds) to wait for connection replies from the Hub
132 time (in seconds) to wait for connection replies from the Hub
132 [Default: 10]
133 [Default: 10]
133
134
134 #-------------- session related args ----------------
135 #-------------- session related args ----------------
135
136
136 config : Config object
137 config : Config object
137 If specified, this will be relayed to the Session for configuration
138 If specified, this will be relayed to the Session for configuration
138 username : str
139 username : str
139 set username for the session object
140 set username for the session object
140 packer : str (import_string) or callable
141 packer : str (import_string) or callable
141 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
142 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
142 function to serialize messages. Must support same input as
143 function to serialize messages. Must support same input as
143 JSON, and output must be bytes.
144 JSON, and output must be bytes.
144 You can pass a callable directly as `pack`
145 You can pass a callable directly as `pack`
145 unpacker : str (import_string) or callable
146 unpacker : str (import_string) or callable
146 The inverse of packer. Only necessary if packer is specified as *not* one
147 The inverse of packer. Only necessary if packer is specified as *not* one
147 of 'json' or 'pickle'.
148 of 'json' or 'pickle'.
148
149
149 #-------------- ssh related args ----------------
150 #-------------- ssh related args ----------------
150 # These are args for configuring the ssh tunnel to be used
151 # These are args for configuring the ssh tunnel to be used
151 # credentials are used to forward connections over ssh to the Controller
152 # credentials are used to forward connections over ssh to the Controller
152 # Note that the ip given in `addr` needs to be relative to sshserver
153 # Note that the ip given in `addr` needs to be relative to sshserver
153 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
154 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
154 # and set sshserver as the same machine the Controller is on. However,
155 # and set sshserver as the same machine the Controller is on. However,
155 # the only requirement is that sshserver is able to see the Controller
156 # the only requirement is that sshserver is able to see the Controller
156 # (i.e. is within the same trusted network).
157 # (i.e. is within the same trusted network).
157
158
158 sshserver : str
159 sshserver : str
159 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
160 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
160 If keyfile or password is specified, and this is not, it will default to
161 If keyfile or password is specified, and this is not, it will default to
161 the ip given in addr.
162 the ip given in addr.
162 sshkey : str; path to public ssh key file
163 sshkey : str; path to public ssh key file
163 This specifies a key to be used in ssh login, default None.
164 This specifies a key to be used in ssh login, default None.
164 Regular default ssh keys will be used without specifying this argument.
165 Regular default ssh keys will be used without specifying this argument.
165 password : str
166 password : str
166 Your ssh password to sshserver. Note that if this is left None,
167 Your ssh password to sshserver. Note that if this is left None,
167 you will be prompted for it if passwordless key based login is unavailable.
168 you will be prompted for it if passwordless key based login is unavailable.
168 paramiko : bool
169 paramiko : bool
169 flag for whether to use paramiko instead of shell ssh for tunneling.
170 flag for whether to use paramiko instead of shell ssh for tunneling.
170 [default: True on win32, False else]
171 [default: True on win32, False else]
171
172
172 ------- exec authentication args -------
173 ------- exec authentication args -------
173 If even localhost is untrusted, you can have some protection against
174 If even localhost is untrusted, you can have some protection against
174 unauthorized execution by signing messages with HMAC digests.
175 unauthorized execution by signing messages with HMAC digests.
175 Messages are still sent as cleartext, so if someone can snoop your
176 Messages are still sent as cleartext, so if someone can snoop your
176 loopback traffic this will not protect your privacy, but will prevent
177 loopback traffic this will not protect your privacy, but will prevent
177 unauthorized execution.
178 unauthorized execution.
178
179
179 exec_key : str
180 exec_key : str
180 an authentication key or file containing a key
181 an authentication key or file containing a key
181 default: None
182 default: None
182
183
183
184
184 Attributes
185 Attributes
185 ----------
186 ----------
186
187
187 ids : list of int engine IDs
188 ids : list of int engine IDs
188 requesting the ids attribute always synchronizes
189 requesting the ids attribute always synchronizes
189 the registration state. To request ids without synchronization,
190 the registration state. To request ids without synchronization,
190 use semi-private _ids attributes.
191 use semi-private _ids attributes.
191
192
192 history : list of msg_ids
193 history : list of msg_ids
193 a list of msg_ids, keeping track of all the execution
194 a list of msg_ids, keeping track of all the execution
194 messages you have submitted in order.
195 messages you have submitted in order.
195
196
196 outstanding : set of msg_ids
197 outstanding : set of msg_ids
197 a set of msg_ids that have been submitted, but whose
198 a set of msg_ids that have been submitted, but whose
198 results have not yet been received.
199 results have not yet been received.
199
200
200 results : dict
201 results : dict
201 a dict of all our results, keyed by msg_id
202 a dict of all our results, keyed by msg_id
202
203
203 block : bool
204 block : bool
204 determines default behavior when block not specified
205 determines default behavior when block not specified
205 in execution methods
206 in execution methods
206
207
207 Methods
208 Methods
208 -------
209 -------
209
210
210 spin
211 spin
211 flushes incoming results and registration state changes
212 flushes incoming results and registration state changes
212 control methods spin, and requesting `ids` also ensures up to date
213 control methods spin, and requesting `ids` also ensures up to date
213
214
214 wait
215 wait
215 wait on one or more msg_ids
216 wait on one or more msg_ids
216
217
217 execution methods
218 execution methods
218 apply
219 apply
219 legacy: execute, run
220 legacy: execute, run
220
221
221 data movement
222 data movement
222 push, pull, scatter, gather
223 push, pull, scatter, gather
223
224
224 query methods
225 query methods
225 queue_status, get_result, purge, result_status
226 queue_status, get_result, purge, result_status
226
227
227 control methods
228 control methods
228 abort, shutdown
229 abort, shutdown
229
230
230 """
231 """
231
232
232
233
233 block = Bool(False)
234 block = Bool(False)
234 outstanding = Set()
235 outstanding = Set()
235 results = Instance('collections.defaultdict', (dict,))
236 results = Instance('collections.defaultdict', (dict,))
236 metadata = Instance('collections.defaultdict', (Metadata,))
237 metadata = Instance('collections.defaultdict', (Metadata,))
237 history = List()
238 history = List()
238 debug = Bool(False)
239 debug = Bool(False)
239 profile=Unicode('default')
240 profile=Unicode('default')
240
241
241 _outstanding_dict = Instance('collections.defaultdict', (set,))
242 _outstanding_dict = Instance('collections.defaultdict', (set,))
242 _ids = List()
243 _ids = List()
243 _connected=Bool(False)
244 _connected=Bool(False)
244 _ssh=Bool(False)
245 _ssh=Bool(False)
245 _context = Instance('zmq.Context')
246 _context = Instance('zmq.Context')
246 _config = Dict()
247 _config = Dict()
247 _engines=Instance(util.ReverseDict, (), {})
248 _engines=Instance(util.ReverseDict, (), {})
248 # _hub_socket=Instance('zmq.Socket')
249 # _hub_socket=Instance('zmq.Socket')
249 _query_socket=Instance('zmq.Socket')
250 _query_socket=Instance('zmq.Socket')
250 _control_socket=Instance('zmq.Socket')
251 _control_socket=Instance('zmq.Socket')
251 _iopub_socket=Instance('zmq.Socket')
252 _iopub_socket=Instance('zmq.Socket')
252 _notification_socket=Instance('zmq.Socket')
253 _notification_socket=Instance('zmq.Socket')
253 _mux_socket=Instance('zmq.Socket')
254 _mux_socket=Instance('zmq.Socket')
254 _task_socket=Instance('zmq.Socket')
255 _task_socket=Instance('zmq.Socket')
255 _task_scheme=Unicode()
256 _task_scheme=Unicode()
256 _closed = False
257 _closed = False
257 _ignored_control_replies=Int(0)
258 _ignored_control_replies=Int(0)
258 _ignored_hub_replies=Int(0)
259 _ignored_hub_replies=Int(0)
259
260
260 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
261 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
261 context=None, debug=False, exec_key=None,
262 context=None, debug=False, exec_key=None,
262 sshserver=None, sshkey=None, password=None, paramiko=None,
263 sshserver=None, sshkey=None, password=None, paramiko=None,
263 timeout=10, **extra_args
264 timeout=10, **extra_args
264 ):
265 ):
265 super(Client, self).__init__(debug=debug, profile=profile)
266 super(Client, self).__init__(debug=debug, profile=profile)
266 if context is None:
267 if context is None:
267 context = zmq.Context.instance()
268 context = zmq.Context.instance()
268 self._context = context
269 self._context = context
269
270
270
271
271 self._setup_profile_dir(profile, profile_dir, ipython_dir)
272 self._setup_profile_dir(profile, profile_dir, ipython_dir)
272 if self._cd is not None:
273 if self._cd is not None:
273 if url_or_file is None:
274 if url_or_file is None:
274 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
275 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
275 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
276 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
276 " Please specify at least one of url_or_file or profile."
277 " Please specify at least one of url_or_file or profile."
277
278
278 try:
279 try:
279 util.validate_url(url_or_file)
280 util.validate_url(url_or_file)
280 except AssertionError:
281 except AssertionError:
281 if not os.path.exists(url_or_file):
282 if not os.path.exists(url_or_file):
282 if self._cd:
283 if self._cd:
283 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
284 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
284 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
285 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
285 with open(url_or_file) as f:
286 with open(url_or_file) as f:
286 cfg = json.loads(f.read())
287 cfg = json.loads(f.read())
287 else:
288 else:
288 cfg = {'url':url_or_file}
289 cfg = {'url':url_or_file}
289
290
290 # sync defaults from args, json:
291 # sync defaults from args, json:
291 if sshserver:
292 if sshserver:
292 cfg['ssh'] = sshserver
293 cfg['ssh'] = sshserver
293 if exec_key:
294 if exec_key:
294 cfg['exec_key'] = exec_key
295 cfg['exec_key'] = exec_key
295 exec_key = cfg['exec_key']
296 exec_key = cfg['exec_key']
296 sshserver=cfg['ssh']
297 sshserver=cfg['ssh']
297 url = cfg['url']
298 url = cfg['url']
298 location = cfg.setdefault('location', None)
299 location = cfg.setdefault('location', None)
299 cfg['url'] = util.disambiguate_url(cfg['url'], location)
300 cfg['url'] = util.disambiguate_url(cfg['url'], location)
300 url = cfg['url']
301 url = cfg['url']
301
302
302 self._config = cfg
303 self._config = cfg
303
304
304 self._ssh = bool(sshserver or sshkey or password)
305 self._ssh = bool(sshserver or sshkey or password)
305 if self._ssh and sshserver is None:
306 if self._ssh and sshserver is None:
306 # default to ssh via localhost
307 # default to ssh via localhost
307 sshserver = url.split('://')[1].split(':')[0]
308 sshserver = url.split('://')[1].split(':')[0]
308 if self._ssh and password is None:
309 if self._ssh and password is None:
309 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
310 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
310 password=False
311 password=False
311 else:
312 else:
312 password = getpass("SSH Password for %s: "%sshserver)
313 password = getpass("SSH Password for %s: "%sshserver)
313 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
314 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
314
315
315 # configure and construct the session
316 # configure and construct the session
316 if exec_key is not None:
317 if exec_key is not None:
317 if os.path.isfile(exec_key):
318 if os.path.isfile(exec_key):
318 extra_args['keyfile'] = exec_key
319 extra_args['keyfile'] = exec_key
319 else:
320 else:
320 extra_args['key'] = exec_key
321 extra_args['key'] = exec_key
321 self.session = Session(**extra_args)
322 self.session = Session(**extra_args)
322
323
323 self._query_socket = self._context.socket(zmq.XREQ)
324 self._query_socket = self._context.socket(zmq.XREQ)
324 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
325 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
325 if self._ssh:
326 if self._ssh:
326 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
327 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
327 else:
328 else:
328 self._query_socket.connect(url)
329 self._query_socket.connect(url)
329
330
330 self.session.debug = self.debug
331 self.session.debug = self.debug
331
332
332 self._notification_handlers = {'registration_notification' : self._register_engine,
333 self._notification_handlers = {'registration_notification' : self._register_engine,
333 'unregistration_notification' : self._unregister_engine,
334 'unregistration_notification' : self._unregister_engine,
334 'shutdown_notification' : lambda msg: self.close(),
335 'shutdown_notification' : lambda msg: self.close(),
335 }
336 }
336 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
337 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
337 'apply_reply' : self._handle_apply_reply}
338 'apply_reply' : self._handle_apply_reply}
338 self._connect(sshserver, ssh_kwargs, timeout)
339 self._connect(sshserver, ssh_kwargs, timeout)
339
340
340 def __del__(self):
341 def __del__(self):
341 """cleanup sockets, but _not_ context."""
342 """cleanup sockets, but _not_ context."""
342 self.close()
343 self.close()
343
344
344 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
345 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
345 if ipython_dir is None:
346 if ipython_dir is None:
346 ipython_dir = get_ipython_dir()
347 ipython_dir = get_ipython_dir()
347 if profile_dir is not None:
348 if profile_dir is not None:
348 try:
349 try:
349 self._cd = ProfileDir.find_profile_dir(profile_dir)
350 self._cd = ProfileDir.find_profile_dir(profile_dir)
350 return
351 return
351 except ProfileDirError:
352 except ProfileDirError:
352 pass
353 pass
353 elif profile is not None:
354 elif profile is not None:
354 try:
355 try:
355 self._cd = ProfileDir.find_profile_dir_by_name(
356 self._cd = ProfileDir.find_profile_dir_by_name(
356 ipython_dir, profile)
357 ipython_dir, profile)
357 return
358 return
358 except ProfileDirError:
359 except ProfileDirError:
359 pass
360 pass
360 self._cd = None
361 self._cd = None
361
362
362 def _update_engines(self, engines):
363 def _update_engines(self, engines):
363 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
364 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
364 for k,v in engines.iteritems():
365 for k,v in engines.iteritems():
365 eid = int(k)
366 eid = int(k)
366 self._engines[eid] = bytes(v) # force not unicode
367 self._engines[eid] = bytes(v) # force not unicode
367 self._ids.append(eid)
368 self._ids.append(eid)
368 self._ids = sorted(self._ids)
369 self._ids = sorted(self._ids)
369 if sorted(self._engines.keys()) != range(len(self._engines)) and \
370 if sorted(self._engines.keys()) != range(len(self._engines)) and \
370 self._task_scheme == 'pure' and self._task_socket:
371 self._task_scheme == 'pure' and self._task_socket:
371 self._stop_scheduling_tasks()
372 self._stop_scheduling_tasks()
372
373
373 def _stop_scheduling_tasks(self):
374 def _stop_scheduling_tasks(self):
374 """Stop scheduling tasks because an engine has been unregistered
375 """Stop scheduling tasks because an engine has been unregistered
375 from a pure ZMQ scheduler.
376 from a pure ZMQ scheduler.
376 """
377 """
377 self._task_socket.close()
378 self._task_socket.close()
378 self._task_socket = None
379 self._task_socket = None
379 msg = "An engine has been unregistered, and we are using pure " +\
380 msg = "An engine has been unregistered, and we are using pure " +\
380 "ZMQ task scheduling. Task farming will be disabled."
381 "ZMQ task scheduling. Task farming will be disabled."
381 if self.outstanding:
382 if self.outstanding:
382 msg += " If you were running tasks when this happened, " +\
383 msg += " If you were running tasks when this happened, " +\
383 "some `outstanding` msg_ids may never resolve."
384 "some `outstanding` msg_ids may never resolve."
384 warnings.warn(msg, RuntimeWarning)
385 warnings.warn(msg, RuntimeWarning)
385
386
386 def _build_targets(self, targets):
387 def _build_targets(self, targets):
387 """Turn valid target IDs or 'all' into two lists:
388 """Turn valid target IDs or 'all' into two lists:
388 (int_ids, uuids).
389 (int_ids, uuids).
389 """
390 """
390 if not self._ids:
391 if not self._ids:
391 # flush notification socket if no engines yet, just in case
392 # flush notification socket if no engines yet, just in case
392 if not self.ids:
393 if not self.ids:
393 raise error.NoEnginesRegistered("Can't build targets without any engines")
394 raise error.NoEnginesRegistered("Can't build targets without any engines")
394
395
395 if targets is None:
396 if targets is None:
396 targets = self._ids
397 targets = self._ids
397 elif isinstance(targets, str):
398 elif isinstance(targets, str):
398 if targets.lower() == 'all':
399 if targets.lower() == 'all':
399 targets = self._ids
400 targets = self._ids
400 else:
401 else:
401 raise TypeError("%r not valid str target, must be 'all'"%(targets))
402 raise TypeError("%r not valid str target, must be 'all'"%(targets))
402 elif isinstance(targets, int):
403 elif isinstance(targets, int):
403 if targets < 0:
404 if targets < 0:
404 targets = self.ids[targets]
405 targets = self.ids[targets]
405 if targets not in self._ids:
406 if targets not in self._ids:
406 raise IndexError("No such engine: %i"%targets)
407 raise IndexError("No such engine: %i"%targets)
407 targets = [targets]
408 targets = [targets]
408
409
409 if isinstance(targets, slice):
410 if isinstance(targets, slice):
410 indices = range(len(self._ids))[targets]
411 indices = range(len(self._ids))[targets]
411 ids = self.ids
412 ids = self.ids
412 targets = [ ids[i] for i in indices ]
413 targets = [ ids[i] for i in indices ]
413
414
414 if not isinstance(targets, (tuple, list, xrange)):
415 if not isinstance(targets, (tuple, list, xrange)):
415 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
416 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
416
417
417 return [self._engines[t] for t in targets], list(targets)
418 return [self._engines[t] for t in targets], list(targets)
418
419
419 def _connect(self, sshserver, ssh_kwargs, timeout):
420 def _connect(self, sshserver, ssh_kwargs, timeout):
420 """setup all our socket connections to the cluster. This is called from
421 """setup all our socket connections to the cluster. This is called from
421 __init__."""
422 __init__."""
422
423
423 # Maybe allow reconnecting?
424 # Maybe allow reconnecting?
424 if self._connected:
425 if self._connected:
425 return
426 return
426 self._connected=True
427 self._connected=True
427
428
428 def connect_socket(s, url):
429 def connect_socket(s, url):
429 url = util.disambiguate_url(url, self._config['location'])
430 url = util.disambiguate_url(url, self._config['location'])
430 if self._ssh:
431 if self._ssh:
431 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
432 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
432 else:
433 else:
433 return s.connect(url)
434 return s.connect(url)
434
435
435 self.session.send(self._query_socket, 'connection_request')
436 self.session.send(self._query_socket, 'connection_request')
436 r,w,x = zmq.select([self._query_socket],[],[], timeout)
437 r,w,x = zmq.select([self._query_socket],[],[], timeout)
437 if not r:
438 if not r:
438 raise error.TimeoutError("Hub connection request timed out")
439 raise error.TimeoutError("Hub connection request timed out")
439 idents,msg = self.session.recv(self._query_socket,mode=0)
440 idents,msg = self.session.recv(self._query_socket,mode=0)
440 if self.debug:
441 if self.debug:
441 pprint(msg)
442 pprint(msg)
442 msg = Message(msg)
443 msg = Message(msg)
443 content = msg.content
444 content = msg.content
444 self._config['registration'] = dict(content)
445 self._config['registration'] = dict(content)
445 if content.status == 'ok':
446 if content.status == 'ok':
446 if content.mux:
447 if content.mux:
447 self._mux_socket = self._context.socket(zmq.XREQ)
448 self._mux_socket = self._context.socket(zmq.XREQ)
448 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
449 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
449 connect_socket(self._mux_socket, content.mux)
450 connect_socket(self._mux_socket, content.mux)
450 if content.task:
451 if content.task:
451 self._task_scheme, task_addr = content.task
452 self._task_scheme, task_addr = content.task
452 self._task_socket = self._context.socket(zmq.XREQ)
453 self._task_socket = self._context.socket(zmq.XREQ)
453 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 connect_socket(self._task_socket, task_addr)
455 connect_socket(self._task_socket, task_addr)
455 if content.notification:
456 if content.notification:
456 self._notification_socket = self._context.socket(zmq.SUB)
457 self._notification_socket = self._context.socket(zmq.SUB)
457 connect_socket(self._notification_socket, content.notification)
458 connect_socket(self._notification_socket, content.notification)
458 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
459 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
459 # if content.query:
460 # if content.query:
460 # self._query_socket = self._context.socket(zmq.XREQ)
461 # self._query_socket = self._context.socket(zmq.XREQ)
461 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
462 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
462 # connect_socket(self._query_socket, content.query)
463 # connect_socket(self._query_socket, content.query)
463 if content.control:
464 if content.control:
464 self._control_socket = self._context.socket(zmq.XREQ)
465 self._control_socket = self._context.socket(zmq.XREQ)
465 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
466 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
466 connect_socket(self._control_socket, content.control)
467 connect_socket(self._control_socket, content.control)
467 if content.iopub:
468 if content.iopub:
468 self._iopub_socket = self._context.socket(zmq.SUB)
469 self._iopub_socket = self._context.socket(zmq.SUB)
469 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
470 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
470 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 connect_socket(self._iopub_socket, content.iopub)
472 connect_socket(self._iopub_socket, content.iopub)
472 self._update_engines(dict(content.engines))
473 self._update_engines(dict(content.engines))
473 else:
474 else:
474 self._connected = False
475 self._connected = False
475 raise Exception("Failed to connect!")
476 raise Exception("Failed to connect!")
476
477
477 #--------------------------------------------------------------------------
478 #--------------------------------------------------------------------------
478 # handlers and callbacks for incoming messages
479 # handlers and callbacks for incoming messages
479 #--------------------------------------------------------------------------
480 #--------------------------------------------------------------------------
480
481
481 def _unwrap_exception(self, content):
482 def _unwrap_exception(self, content):
482 """unwrap exception, and remap engine_id to int."""
483 """unwrap exception, and remap engine_id to int."""
483 e = error.unwrap_exception(content)
484 e = error.unwrap_exception(content)
484 # print e.traceback
485 # print e.traceback
485 if e.engine_info:
486 if e.engine_info:
486 e_uuid = e.engine_info['engine_uuid']
487 e_uuid = e.engine_info['engine_uuid']
487 eid = self._engines[e_uuid]
488 eid = self._engines[e_uuid]
488 e.engine_info['engine_id'] = eid
489 e.engine_info['engine_id'] = eid
489 return e
490 return e
490
491
491 def _extract_metadata(self, header, parent, content):
492 def _extract_metadata(self, header, parent, content):
492 md = {'msg_id' : parent['msg_id'],
493 md = {'msg_id' : parent['msg_id'],
493 'received' : datetime.now(),
494 'received' : datetime.now(),
494 'engine_uuid' : header.get('engine', None),
495 'engine_uuid' : header.get('engine', None),
495 'follow' : parent.get('follow', []),
496 'follow' : parent.get('follow', []),
496 'after' : parent.get('after', []),
497 'after' : parent.get('after', []),
497 'status' : content['status'],
498 'status' : content['status'],
498 }
499 }
499
500
500 if md['engine_uuid'] is not None:
501 if md['engine_uuid'] is not None:
501 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
502 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
502
503
503 if 'date' in parent:
504 if 'date' in parent:
504 md['submitted'] = parent['date']
505 md['submitted'] = parent['date']
505 if 'started' in header:
506 if 'started' in header:
506 md['started'] = header['started']
507 md['started'] = header['started']
507 if 'date' in header:
508 if 'date' in header:
508 md['completed'] = header['date']
509 md['completed'] = header['date']
509 return md
510 return md
510
511
511 def _register_engine(self, msg):
512 def _register_engine(self, msg):
512 """Register a new engine, and update our connection info."""
513 """Register a new engine, and update our connection info."""
513 content = msg['content']
514 content = msg['content']
514 eid = content['id']
515 eid = content['id']
515 d = {eid : content['queue']}
516 d = {eid : content['queue']}
516 self._update_engines(d)
517 self._update_engines(d)
517
518
518 def _unregister_engine(self, msg):
519 def _unregister_engine(self, msg):
519 """Unregister an engine that has died."""
520 """Unregister an engine that has died."""
520 content = msg['content']
521 content = msg['content']
521 eid = int(content['id'])
522 eid = int(content['id'])
522 if eid in self._ids:
523 if eid in self._ids:
523 self._ids.remove(eid)
524 self._ids.remove(eid)
524 uuid = self._engines.pop(eid)
525 uuid = self._engines.pop(eid)
525
526
526 self._handle_stranded_msgs(eid, uuid)
527 self._handle_stranded_msgs(eid, uuid)
527
528
528 if self._task_socket and self._task_scheme == 'pure':
529 if self._task_socket and self._task_scheme == 'pure':
529 self._stop_scheduling_tasks()
530 self._stop_scheduling_tasks()
530
531
531 def _handle_stranded_msgs(self, eid, uuid):
532 def _handle_stranded_msgs(self, eid, uuid):
532 """Handle messages known to be on an engine when the engine unregisters.
533 """Handle messages known to be on an engine when the engine unregisters.
533
534
534 It is possible that this will fire prematurely - that is, an engine will
535 It is possible that this will fire prematurely - that is, an engine will
535 go down after completing a result, and the client will be notified
536 go down after completing a result, and the client will be notified
536 of the unregistration and later receive the successful result.
537 of the unregistration and later receive the successful result.
537 """
538 """
538
539
539 outstanding = self._outstanding_dict[uuid]
540 outstanding = self._outstanding_dict[uuid]
540
541
541 for msg_id in list(outstanding):
542 for msg_id in list(outstanding):
542 if msg_id in self.results:
543 if msg_id in self.results:
543 # we already
544 # we already
544 continue
545 continue
545 try:
546 try:
546 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
547 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
547 except:
548 except:
548 content = error.wrap_exception()
549 content = error.wrap_exception()
549 # build a fake message:
550 # build a fake message:
550 parent = {}
551 parent = {}
551 header = {}
552 header = {}
552 parent['msg_id'] = msg_id
553 parent['msg_id'] = msg_id
553 header['engine'] = uuid
554 header['engine'] = uuid
554 header['date'] = datetime.now()
555 header['date'] = datetime.now()
555 msg = dict(parent_header=parent, header=header, content=content)
556 msg = dict(parent_header=parent, header=header, content=content)
556 self._handle_apply_reply(msg)
557 self._handle_apply_reply(msg)
557
558
558 def _handle_execute_reply(self, msg):
559 def _handle_execute_reply(self, msg):
559 """Save the reply to an execute_request into our results.
560 """Save the reply to an execute_request into our results.
560
561
561 execute messages are never actually used. apply is used instead.
562 execute messages are never actually used. apply is used instead.
562 """
563 """
563
564
564 parent = msg['parent_header']
565 parent = msg['parent_header']
565 msg_id = parent['msg_id']
566 msg_id = parent['msg_id']
566 if msg_id not in self.outstanding:
567 if msg_id not in self.outstanding:
567 if msg_id in self.history:
568 if msg_id in self.history:
568 print ("got stale result: %s"%msg_id)
569 print ("got stale result: %s"%msg_id)
569 else:
570 else:
570 print ("got unknown result: %s"%msg_id)
571 print ("got unknown result: %s"%msg_id)
571 else:
572 else:
572 self.outstanding.remove(msg_id)
573 self.outstanding.remove(msg_id)
573 self.results[msg_id] = self._unwrap_exception(msg['content'])
574 self.results[msg_id] = self._unwrap_exception(msg['content'])
574
575
575 def _handle_apply_reply(self, msg):
576 def _handle_apply_reply(self, msg):
576 """Save the reply to an apply_request into our results."""
577 """Save the reply to an apply_request into our results."""
577 parent = msg['parent_header']
578 parent = msg['parent_header']
578 msg_id = parent['msg_id']
579 msg_id = parent['msg_id']
579 if msg_id not in self.outstanding:
580 if msg_id not in self.outstanding:
580 if msg_id in self.history:
581 if msg_id in self.history:
581 print ("got stale result: %s"%msg_id)
582 print ("got stale result: %s"%msg_id)
582 print self.results[msg_id]
583 print self.results[msg_id]
583 print msg
584 print msg
584 else:
585 else:
585 print ("got unknown result: %s"%msg_id)
586 print ("got unknown result: %s"%msg_id)
586 else:
587 else:
587 self.outstanding.remove(msg_id)
588 self.outstanding.remove(msg_id)
588 content = msg['content']
589 content = msg['content']
589 header = msg['header']
590 header = msg['header']
590
591
591 # construct metadata:
592 # construct metadata:
592 md = self.metadata[msg_id]
593 md = self.metadata[msg_id]
593 md.update(self._extract_metadata(header, parent, content))
594 md.update(self._extract_metadata(header, parent, content))
594 # is this redundant?
595 # is this redundant?
595 self.metadata[msg_id] = md
596 self.metadata[msg_id] = md
596
597
597 e_outstanding = self._outstanding_dict[md['engine_uuid']]
598 e_outstanding = self._outstanding_dict[md['engine_uuid']]
598 if msg_id in e_outstanding:
599 if msg_id in e_outstanding:
599 e_outstanding.remove(msg_id)
600 e_outstanding.remove(msg_id)
600
601
601 # construct result:
602 # construct result:
602 if content['status'] == 'ok':
603 if content['status'] == 'ok':
603 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
604 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
604 elif content['status'] == 'aborted':
605 elif content['status'] == 'aborted':
605 self.results[msg_id] = error.TaskAborted(msg_id)
606 self.results[msg_id] = error.TaskAborted(msg_id)
606 elif content['status'] == 'resubmitted':
607 elif content['status'] == 'resubmitted':
607 # TODO: handle resubmission
608 # TODO: handle resubmission
608 pass
609 pass
609 else:
610 else:
610 self.results[msg_id] = self._unwrap_exception(content)
611 self.results[msg_id] = self._unwrap_exception(content)
611
612
612 def _flush_notifications(self):
613 def _flush_notifications(self):
613 """Flush notifications of engine registrations waiting
614 """Flush notifications of engine registrations waiting
614 in ZMQ queue."""
615 in ZMQ queue."""
615 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
616 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
616 while msg is not None:
617 while msg is not None:
617 if self.debug:
618 if self.debug:
618 pprint(msg)
619 pprint(msg)
619 msg_type = msg['msg_type']
620 msg_type = msg['msg_type']
620 handler = self._notification_handlers.get(msg_type, None)
621 handler = self._notification_handlers.get(msg_type, None)
621 if handler is None:
622 if handler is None:
622 raise Exception("Unhandled message type: %s"%msg.msg_type)
623 raise Exception("Unhandled message type: %s"%msg.msg_type)
623 else:
624 else:
624 handler(msg)
625 handler(msg)
625 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
626 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
626
627
627 def _flush_results(self, sock):
628 def _flush_results(self, sock):
628 """Flush task or queue results waiting in ZMQ queue."""
629 """Flush task or queue results waiting in ZMQ queue."""
629 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
630 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
630 while msg is not None:
631 while msg is not None:
631 if self.debug:
632 if self.debug:
632 pprint(msg)
633 pprint(msg)
633 msg_type = msg['msg_type']
634 msg_type = msg['msg_type']
634 handler = self._queue_handlers.get(msg_type, None)
635 handler = self._queue_handlers.get(msg_type, None)
635 if handler is None:
636 if handler is None:
636 raise Exception("Unhandled message type: %s"%msg.msg_type)
637 raise Exception("Unhandled message type: %s"%msg.msg_type)
637 else:
638 else:
638 handler(msg)
639 handler(msg)
639 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
640 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
640
641
641 def _flush_control(self, sock):
642 def _flush_control(self, sock):
642 """Flush replies from the control channel waiting
643 """Flush replies from the control channel waiting
643 in the ZMQ queue.
644 in the ZMQ queue.
644
645
645 Currently: ignore them."""
646 Currently: ignore them."""
646 if self._ignored_control_replies <= 0:
647 if self._ignored_control_replies <= 0:
647 return
648 return
648 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 while msg is not None:
650 while msg is not None:
650 self._ignored_control_replies -= 1
651 self._ignored_control_replies -= 1
651 if self.debug:
652 if self.debug:
652 pprint(msg)
653 pprint(msg)
653 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
654 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
654
655
655 def _flush_ignored_control(self):
656 def _flush_ignored_control(self):
656 """flush ignored control replies"""
657 """flush ignored control replies"""
657 while self._ignored_control_replies > 0:
658 while self._ignored_control_replies > 0:
658 self.session.recv(self._control_socket)
659 self.session.recv(self._control_socket)
659 self._ignored_control_replies -= 1
660 self._ignored_control_replies -= 1
660
661
661 def _flush_ignored_hub_replies(self):
662 def _flush_ignored_hub_replies(self):
662 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
663 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
663 while msg is not None:
664 while msg is not None:
664 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
665 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
665
666
666 def _flush_iopub(self, sock):
667 def _flush_iopub(self, sock):
667 """Flush replies from the iopub channel waiting
668 """Flush replies from the iopub channel waiting
668 in the ZMQ queue.
669 in the ZMQ queue.
669 """
670 """
670 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
671 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
671 while msg is not None:
672 while msg is not None:
672 if self.debug:
673 if self.debug:
673 pprint(msg)
674 pprint(msg)
674 parent = msg['parent_header']
675 parent = msg['parent_header']
675 msg_id = parent['msg_id']
676 msg_id = parent['msg_id']
676 content = msg['content']
677 content = msg['content']
677 header = msg['header']
678 header = msg['header']
678 msg_type = msg['msg_type']
679 msg_type = msg['msg_type']
679
680
680 # init metadata:
681 # init metadata:
681 md = self.metadata[msg_id]
682 md = self.metadata[msg_id]
682
683
683 if msg_type == 'stream':
684 if msg_type == 'stream':
684 name = content['name']
685 name = content['name']
685 s = md[name] or ''
686 s = md[name] or ''
686 md[name] = s + content['data']
687 md[name] = s + content['data']
687 elif msg_type == 'pyerr':
688 elif msg_type == 'pyerr':
688 md.update({'pyerr' : self._unwrap_exception(content)})
689 md.update({'pyerr' : self._unwrap_exception(content)})
689 elif msg_type == 'pyin':
690 elif msg_type == 'pyin':
690 md.update({'pyin' : content['code']})
691 md.update({'pyin' : content['code']})
691 else:
692 else:
692 md.update({msg_type : content.get('data', '')})
693 md.update({msg_type : content.get('data', '')})
693
694
694 # reduntant?
695 # reduntant?
695 self.metadata[msg_id] = md
696 self.metadata[msg_id] = md
696
697
697 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
698 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
698
699
699 #--------------------------------------------------------------------------
700 #--------------------------------------------------------------------------
700 # len, getitem
701 # len, getitem
701 #--------------------------------------------------------------------------
702 #--------------------------------------------------------------------------
702
703
703 def __len__(self):
704 def __len__(self):
704 """len(client) returns # of engines."""
705 """len(client) returns # of engines."""
705 return len(self.ids)
706 return len(self.ids)
706
707
707 def __getitem__(self, key):
708 def __getitem__(self, key):
708 """index access returns DirectView multiplexer objects
709 """index access returns DirectView multiplexer objects
709
710
710 Must be int, slice, or list/tuple/xrange of ints"""
711 Must be int, slice, or list/tuple/xrange of ints"""
711 if not isinstance(key, (int, slice, tuple, list, xrange)):
712 if not isinstance(key, (int, slice, tuple, list, xrange)):
712 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
713 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
713 else:
714 else:
714 return self.direct_view(key)
715 return self.direct_view(key)
715
716
716 #--------------------------------------------------------------------------
717 #--------------------------------------------------------------------------
717 # Begin public methods
718 # Begin public methods
718 #--------------------------------------------------------------------------
719 #--------------------------------------------------------------------------
719
720
720 @property
721 @property
721 def ids(self):
722 def ids(self):
722 """Always up-to-date ids property."""
723 """Always up-to-date ids property."""
723 self._flush_notifications()
724 self._flush_notifications()
724 # always copy:
725 # always copy:
725 return list(self._ids)
726 return list(self._ids)
726
727
727 def close(self):
728 def close(self):
728 if self._closed:
729 if self._closed:
729 return
730 return
730 snames = filter(lambda n: n.endswith('socket'), dir(self))
731 snames = filter(lambda n: n.endswith('socket'), dir(self))
731 for socket in map(lambda name: getattr(self, name), snames):
732 for socket in map(lambda name: getattr(self, name), snames):
732 if isinstance(socket, zmq.Socket) and not socket.closed:
733 if isinstance(socket, zmq.Socket) and not socket.closed:
733 socket.close()
734 socket.close()
734 self._closed = True
735 self._closed = True
735
736
736 def spin(self):
737 def spin(self):
737 """Flush any registration notifications and execution results
738 """Flush any registration notifications and execution results
738 waiting in the ZMQ queue.
739 waiting in the ZMQ queue.
739 """
740 """
740 if self._notification_socket:
741 if self._notification_socket:
741 self._flush_notifications()
742 self._flush_notifications()
742 if self._mux_socket:
743 if self._mux_socket:
743 self._flush_results(self._mux_socket)
744 self._flush_results(self._mux_socket)
744 if self._task_socket:
745 if self._task_socket:
745 self._flush_results(self._task_socket)
746 self._flush_results(self._task_socket)
746 if self._control_socket:
747 if self._control_socket:
747 self._flush_control(self._control_socket)
748 self._flush_control(self._control_socket)
748 if self._iopub_socket:
749 if self._iopub_socket:
749 self._flush_iopub(self._iopub_socket)
750 self._flush_iopub(self._iopub_socket)
750 if self._query_socket:
751 if self._query_socket:
751 self._flush_ignored_hub_replies()
752 self._flush_ignored_hub_replies()
752
753
753 def wait(self, jobs=None, timeout=-1):
754 def wait(self, jobs=None, timeout=-1):
754 """waits on one or more `jobs`, for up to `timeout` seconds.
755 """waits on one or more `jobs`, for up to `timeout` seconds.
755
756
756 Parameters
757 Parameters
757 ----------
758 ----------
758
759
759 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
760 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
760 ints are indices to self.history
761 ints are indices to self.history
761 strs are msg_ids
762 strs are msg_ids
762 default: wait on all outstanding messages
763 default: wait on all outstanding messages
763 timeout : float
764 timeout : float
764 a time in seconds, after which to give up.
765 a time in seconds, after which to give up.
765 default is -1, which means no timeout
766 default is -1, which means no timeout
766
767
767 Returns
768 Returns
768 -------
769 -------
769
770
770 True : when all msg_ids are done
771 True : when all msg_ids are done
771 False : timeout reached, some msg_ids still outstanding
772 False : timeout reached, some msg_ids still outstanding
772 """
773 """
773 tic = time.time()
774 tic = time.time()
774 if jobs is None:
775 if jobs is None:
775 theids = self.outstanding
776 theids = self.outstanding
776 else:
777 else:
777 if isinstance(jobs, (int, str, AsyncResult)):
778 if isinstance(jobs, (int, str, AsyncResult)):
778 jobs = [jobs]
779 jobs = [jobs]
779 theids = set()
780 theids = set()
780 for job in jobs:
781 for job in jobs:
781 if isinstance(job, int):
782 if isinstance(job, int):
782 # index access
783 # index access
783 job = self.history[job]
784 job = self.history[job]
784 elif isinstance(job, AsyncResult):
785 elif isinstance(job, AsyncResult):
785 map(theids.add, job.msg_ids)
786 map(theids.add, job.msg_ids)
786 continue
787 continue
787 theids.add(job)
788 theids.add(job)
788 if not theids.intersection(self.outstanding):
789 if not theids.intersection(self.outstanding):
789 return True
790 return True
790 self.spin()
791 self.spin()
791 while theids.intersection(self.outstanding):
792 while theids.intersection(self.outstanding):
792 if timeout >= 0 and ( time.time()-tic ) > timeout:
793 if timeout >= 0 and ( time.time()-tic ) > timeout:
793 break
794 break
794 time.sleep(1e-3)
795 time.sleep(1e-3)
795 self.spin()
796 self.spin()
796 return len(theids.intersection(self.outstanding)) == 0
797 return len(theids.intersection(self.outstanding)) == 0
797
798
798 #--------------------------------------------------------------------------
799 #--------------------------------------------------------------------------
799 # Control methods
800 # Control methods
800 #--------------------------------------------------------------------------
801 #--------------------------------------------------------------------------
801
802
802 @spin_first
803 @spin_first
803 def clear(self, targets=None, block=None):
804 def clear(self, targets=None, block=None):
804 """Clear the namespace in target(s)."""
805 """Clear the namespace in target(s)."""
805 block = self.block if block is None else block
806 block = self.block if block is None else block
806 targets = self._build_targets(targets)[0]
807 targets = self._build_targets(targets)[0]
807 for t in targets:
808 for t in targets:
808 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
809 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
809 error = False
810 error = False
810 if block:
811 if block:
811 self._flush_ignored_control()
812 self._flush_ignored_control()
812 for i in range(len(targets)):
813 for i in range(len(targets)):
813 idents,msg = self.session.recv(self._control_socket,0)
814 idents,msg = self.session.recv(self._control_socket,0)
814 if self.debug:
815 if self.debug:
815 pprint(msg)
816 pprint(msg)
816 if msg['content']['status'] != 'ok':
817 if msg['content']['status'] != 'ok':
817 error = self._unwrap_exception(msg['content'])
818 error = self._unwrap_exception(msg['content'])
818 else:
819 else:
819 self._ignored_control_replies += len(targets)
820 self._ignored_control_replies += len(targets)
820 if error:
821 if error:
821 raise error
822 raise error
822
823
823
824
824 @spin_first
825 @spin_first
825 def abort(self, jobs=None, targets=None, block=None):
826 def abort(self, jobs=None, targets=None, block=None):
826 """Abort specific jobs from the execution queues of target(s).
827 """Abort specific jobs from the execution queues of target(s).
827
828
828 This is a mechanism to prevent jobs that have already been submitted
829 This is a mechanism to prevent jobs that have already been submitted
829 from executing.
830 from executing.
830
831
831 Parameters
832 Parameters
832 ----------
833 ----------
833
834
834 jobs : msg_id, list of msg_ids, or AsyncResult
835 jobs : msg_id, list of msg_ids, or AsyncResult
835 The jobs to be aborted
836 The jobs to be aborted
836
837
837
838
838 """
839 """
839 block = self.block if block is None else block
840 block = self.block if block is None else block
840 targets = self._build_targets(targets)[0]
841 targets = self._build_targets(targets)[0]
841 msg_ids = []
842 msg_ids = []
842 if isinstance(jobs, (basestring,AsyncResult)):
843 if isinstance(jobs, (basestring,AsyncResult)):
843 jobs = [jobs]
844 jobs = [jobs]
844 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
845 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
845 if bad_ids:
846 if bad_ids:
846 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
847 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
847 for j in jobs:
848 for j in jobs:
848 if isinstance(j, AsyncResult):
849 if isinstance(j, AsyncResult):
849 msg_ids.extend(j.msg_ids)
850 msg_ids.extend(j.msg_ids)
850 else:
851 else:
851 msg_ids.append(j)
852 msg_ids.append(j)
852 content = dict(msg_ids=msg_ids)
853 content = dict(msg_ids=msg_ids)
853 for t in targets:
854 for t in targets:
854 self.session.send(self._control_socket, 'abort_request',
855 self.session.send(self._control_socket, 'abort_request',
855 content=content, ident=t)
856 content=content, ident=t)
856 error = False
857 error = False
857 if block:
858 if block:
858 self._flush_ignored_control()
859 self._flush_ignored_control()
859 for i in range(len(targets)):
860 for i in range(len(targets)):
860 idents,msg = self.session.recv(self._control_socket,0)
861 idents,msg = self.session.recv(self._control_socket,0)
861 if self.debug:
862 if self.debug:
862 pprint(msg)
863 pprint(msg)
863 if msg['content']['status'] != 'ok':
864 if msg['content']['status'] != 'ok':
864 error = self._unwrap_exception(msg['content'])
865 error = self._unwrap_exception(msg['content'])
865 else:
866 else:
866 self._ignored_control_replies += len(targets)
867 self._ignored_control_replies += len(targets)
867 if error:
868 if error:
868 raise error
869 raise error
869
870
870 @spin_first
871 @spin_first
871 def shutdown(self, targets=None, restart=False, hub=False, block=None):
872 def shutdown(self, targets=None, restart=False, hub=False, block=None):
872 """Terminates one or more engine processes, optionally including the hub."""
873 """Terminates one or more engine processes, optionally including the hub."""
873 block = self.block if block is None else block
874 block = self.block if block is None else block
874 if hub:
875 if hub:
875 targets = 'all'
876 targets = 'all'
876 targets = self._build_targets(targets)[0]
877 targets = self._build_targets(targets)[0]
877 for t in targets:
878 for t in targets:
878 self.session.send(self._control_socket, 'shutdown_request',
879 self.session.send(self._control_socket, 'shutdown_request',
879 content={'restart':restart},ident=t)
880 content={'restart':restart},ident=t)
880 error = False
881 error = False
881 if block or hub:
882 if block or hub:
882 self._flush_ignored_control()
883 self._flush_ignored_control()
883 for i in range(len(targets)):
884 for i in range(len(targets)):
884 idents,msg = self.session.recv(self._control_socket, 0)
885 idents,msg = self.session.recv(self._control_socket, 0)
885 if self.debug:
886 if self.debug:
886 pprint(msg)
887 pprint(msg)
887 if msg['content']['status'] != 'ok':
888 if msg['content']['status'] != 'ok':
888 error = self._unwrap_exception(msg['content'])
889 error = self._unwrap_exception(msg['content'])
889 else:
890 else:
890 self._ignored_control_replies += len(targets)
891 self._ignored_control_replies += len(targets)
891
892
892 if hub:
893 if hub:
893 time.sleep(0.25)
894 time.sleep(0.25)
894 self.session.send(self._query_socket, 'shutdown_request')
895 self.session.send(self._query_socket, 'shutdown_request')
895 idents,msg = self.session.recv(self._query_socket, 0)
896 idents,msg = self.session.recv(self._query_socket, 0)
896 if self.debug:
897 if self.debug:
897 pprint(msg)
898 pprint(msg)
898 if msg['content']['status'] != 'ok':
899 if msg['content']['status'] != 'ok':
899 error = self._unwrap_exception(msg['content'])
900 error = self._unwrap_exception(msg['content'])
900
901
901 if error:
902 if error:
902 raise error
903 raise error
903
904
904 #--------------------------------------------------------------------------
905 #--------------------------------------------------------------------------
905 # Execution related methods
906 # Execution related methods
906 #--------------------------------------------------------------------------
907 #--------------------------------------------------------------------------
907
908
908 def _maybe_raise(self, result):
909 def _maybe_raise(self, result):
909 """wrapper for maybe raising an exception if apply failed."""
910 """wrapper for maybe raising an exception if apply failed."""
910 if isinstance(result, error.RemoteError):
911 if isinstance(result, error.RemoteError):
911 raise result
912 raise result
912
913
913 return result
914 return result
914
915
915 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
916 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
916 ident=None):
917 ident=None):
917 """construct and send an apply message via a socket.
918 """construct and send an apply message via a socket.
918
919
919 This is the principal method with which all engine execution is performed by views.
920 This is the principal method with which all engine execution is performed by views.
920 """
921 """
921
922
922 assert not self._closed, "cannot use me anymore, I'm closed!"
923 assert not self._closed, "cannot use me anymore, I'm closed!"
923 # defaults:
924 # defaults:
924 args = args if args is not None else []
925 args = args if args is not None else []
925 kwargs = kwargs if kwargs is not None else {}
926 kwargs = kwargs if kwargs is not None else {}
926 subheader = subheader if subheader is not None else {}
927 subheader = subheader if subheader is not None else {}
927
928
928 # validate arguments
929 # validate arguments
929 if not callable(f):
930 if not callable(f):
930 raise TypeError("f must be callable, not %s"%type(f))
931 raise TypeError("f must be callable, not %s"%type(f))
931 if not isinstance(args, (tuple, list)):
932 if not isinstance(args, (tuple, list)):
932 raise TypeError("args must be tuple or list, not %s"%type(args))
933 raise TypeError("args must be tuple or list, not %s"%type(args))
933 if not isinstance(kwargs, dict):
934 if not isinstance(kwargs, dict):
934 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
935 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
935 if not isinstance(subheader, dict):
936 if not isinstance(subheader, dict):
936 raise TypeError("subheader must be dict, not %s"%type(subheader))
937 raise TypeError("subheader must be dict, not %s"%type(subheader))
937
938
938 bufs = util.pack_apply_message(f,args,kwargs)
939 bufs = util.pack_apply_message(f,args,kwargs)
939
940
940 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
941 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
941 subheader=subheader, track=track)
942 subheader=subheader, track=track)
942
943
943 msg_id = msg['msg_id']
944 msg_id = msg['msg_id']
944 self.outstanding.add(msg_id)
945 self.outstanding.add(msg_id)
945 if ident:
946 if ident:
946 # possibly routed to a specific engine
947 # possibly routed to a specific engine
947 if isinstance(ident, list):
948 if isinstance(ident, list):
948 ident = ident[-1]
949 ident = ident[-1]
949 if ident in self._engines.values():
950 if ident in self._engines.values():
950 # save for later, in case of engine death
951 # save for later, in case of engine death
951 self._outstanding_dict[ident].add(msg_id)
952 self._outstanding_dict[ident].add(msg_id)
952 self.history.append(msg_id)
953 self.history.append(msg_id)
953 self.metadata[msg_id]['submitted'] = datetime.now()
954 self.metadata[msg_id]['submitted'] = datetime.now()
954
955
955 return msg
956 return msg
956
957
957 #--------------------------------------------------------------------------
958 #--------------------------------------------------------------------------
958 # construct a View object
959 # construct a View object
959 #--------------------------------------------------------------------------
960 #--------------------------------------------------------------------------
960
961
961 def load_balanced_view(self, targets=None):
962 def load_balanced_view(self, targets=None):
962 """construct a DirectView object.
963 """construct a DirectView object.
963
964
964 If no arguments are specified, create a LoadBalancedView
965 If no arguments are specified, create a LoadBalancedView
965 using all engines.
966 using all engines.
966
967
967 Parameters
968 Parameters
968 ----------
969 ----------
969
970
970 targets: list,slice,int,etc. [default: use all engines]
971 targets: list,slice,int,etc. [default: use all engines]
971 The subset of engines across which to load-balance
972 The subset of engines across which to load-balance
972 """
973 """
973 if targets is not None:
974 if targets is not None:
974 targets = self._build_targets(targets)[1]
975 targets = self._build_targets(targets)[1]
975 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
976 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
976
977
977 def direct_view(self, targets='all'):
978 def direct_view(self, targets='all'):
978 """construct a DirectView object.
979 """construct a DirectView object.
979
980
980 If no targets are specified, create a DirectView
981 If no targets are specified, create a DirectView
981 using all engines.
982 using all engines.
982
983
983 Parameters
984 Parameters
984 ----------
985 ----------
985
986
986 targets: list,slice,int,etc. [default: use all engines]
987 targets: list,slice,int,etc. [default: use all engines]
987 The engines to use for the View
988 The engines to use for the View
988 """
989 """
989 single = isinstance(targets, int)
990 single = isinstance(targets, int)
990 targets = self._build_targets(targets)[1]
991 targets = self._build_targets(targets)[1]
991 if single:
992 if single:
992 targets = targets[0]
993 targets = targets[0]
993 return DirectView(client=self, socket=self._mux_socket, targets=targets)
994 return DirectView(client=self, socket=self._mux_socket, targets=targets)
994
995
995 #--------------------------------------------------------------------------
996 #--------------------------------------------------------------------------
996 # Query methods
997 # Query methods
997 #--------------------------------------------------------------------------
998 #--------------------------------------------------------------------------
998
999
999 @spin_first
1000 @spin_first
1000 def get_result(self, indices_or_msg_ids=None, block=None):
1001 def get_result(self, indices_or_msg_ids=None, block=None):
1001 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1002 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1002
1003
1003 If the client already has the results, no request to the Hub will be made.
1004 If the client already has the results, no request to the Hub will be made.
1004
1005
1005 This is a convenient way to construct AsyncResult objects, which are wrappers
1006 This is a convenient way to construct AsyncResult objects, which are wrappers
1006 that include metadata about execution, and allow for awaiting results that
1007 that include metadata about execution, and allow for awaiting results that
1007 were not submitted by this Client.
1008 were not submitted by this Client.
1008
1009
1009 It can also be a convenient way to retrieve the metadata associated with
1010 It can also be a convenient way to retrieve the metadata associated with
1010 blocking execution, since it always retrieves
1011 blocking execution, since it always retrieves
1011
1012
1012 Examples
1013 Examples
1013 --------
1014 --------
1014 ::
1015 ::
1015
1016
1016 In [10]: r = client.apply()
1017 In [10]: r = client.apply()
1017
1018
1018 Parameters
1019 Parameters
1019 ----------
1020 ----------
1020
1021
1021 indices_or_msg_ids : integer history index, str msg_id, or list of either
1022 indices_or_msg_ids : integer history index, str msg_id, or list of either
1022 The indices or msg_ids of indices to be retrieved
1023 The indices or msg_ids of indices to be retrieved
1023
1024
1024 block : bool
1025 block : bool
1025 Whether to wait for the result to be done
1026 Whether to wait for the result to be done
1026
1027
1027 Returns
1028 Returns
1028 -------
1029 -------
1029
1030
1030 AsyncResult
1031 AsyncResult
1031 A single AsyncResult object will always be returned.
1032 A single AsyncResult object will always be returned.
1032
1033
1033 AsyncHubResult
1034 AsyncHubResult
1034 A subclass of AsyncResult that retrieves results from the Hub
1035 A subclass of AsyncResult that retrieves results from the Hub
1035
1036
1036 """
1037 """
1037 block = self.block if block is None else block
1038 block = self.block if block is None else block
1038 if indices_or_msg_ids is None:
1039 if indices_or_msg_ids is None:
1039 indices_or_msg_ids = -1
1040 indices_or_msg_ids = -1
1040
1041
1041 if not isinstance(indices_or_msg_ids, (list,tuple)):
1042 if not isinstance(indices_or_msg_ids, (list,tuple)):
1042 indices_or_msg_ids = [indices_or_msg_ids]
1043 indices_or_msg_ids = [indices_or_msg_ids]
1043
1044
1044 theids = []
1045 theids = []
1045 for id in indices_or_msg_ids:
1046 for id in indices_or_msg_ids:
1046 if isinstance(id, int):
1047 if isinstance(id, int):
1047 id = self.history[id]
1048 id = self.history[id]
1048 if not isinstance(id, str):
1049 if not isinstance(id, str):
1049 raise TypeError("indices must be str or int, not %r"%id)
1050 raise TypeError("indices must be str or int, not %r"%id)
1050 theids.append(id)
1051 theids.append(id)
1051
1052
1052 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1053 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1053 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1054 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1054
1055
1055 if remote_ids:
1056 if remote_ids:
1056 ar = AsyncHubResult(self, msg_ids=theids)
1057 ar = AsyncHubResult(self, msg_ids=theids)
1057 else:
1058 else:
1058 ar = AsyncResult(self, msg_ids=theids)
1059 ar = AsyncResult(self, msg_ids=theids)
1059
1060
1060 if block:
1061 if block:
1061 ar.wait()
1062 ar.wait()
1062
1063
1063 return ar
1064 return ar
1064
1065
1065 @spin_first
1066 @spin_first
1066 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1067 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1067 """Resubmit one or more tasks.
1068 """Resubmit one or more tasks.
1068
1069
1069 in-flight tasks may not be resubmitted.
1070 in-flight tasks may not be resubmitted.
1070
1071
1071 Parameters
1072 Parameters
1072 ----------
1073 ----------
1073
1074
1074 indices_or_msg_ids : integer history index, str msg_id, or list of either
1075 indices_or_msg_ids : integer history index, str msg_id, or list of either
1075 The indices or msg_ids of indices to be retrieved
1076 The indices or msg_ids of indices to be retrieved
1076
1077
1077 block : bool
1078 block : bool
1078 Whether to wait for the result to be done
1079 Whether to wait for the result to be done
1079
1080
1080 Returns
1081 Returns
1081 -------
1082 -------
1082
1083
1083 AsyncHubResult
1084 AsyncHubResult
1084 A subclass of AsyncResult that retrieves results from the Hub
1085 A subclass of AsyncResult that retrieves results from the Hub
1085
1086
1086 """
1087 """
1087 block = self.block if block is None else block
1088 block = self.block if block is None else block
1088 if indices_or_msg_ids is None:
1089 if indices_or_msg_ids is None:
1089 indices_or_msg_ids = -1
1090 indices_or_msg_ids = -1
1090
1091
1091 if not isinstance(indices_or_msg_ids, (list,tuple)):
1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1092 indices_or_msg_ids = [indices_or_msg_ids]
1093 indices_or_msg_ids = [indices_or_msg_ids]
1093
1094
1094 theids = []
1095 theids = []
1095 for id in indices_or_msg_ids:
1096 for id in indices_or_msg_ids:
1096 if isinstance(id, int):
1097 if isinstance(id, int):
1097 id = self.history[id]
1098 id = self.history[id]
1098 if not isinstance(id, str):
1099 if not isinstance(id, str):
1099 raise TypeError("indices must be str or int, not %r"%id)
1100 raise TypeError("indices must be str or int, not %r"%id)
1100 theids.append(id)
1101 theids.append(id)
1101
1102
1102 for msg_id in theids:
1103 for msg_id in theids:
1103 self.outstanding.discard(msg_id)
1104 self.outstanding.discard(msg_id)
1104 if msg_id in self.history:
1105 if msg_id in self.history:
1105 self.history.remove(msg_id)
1106 self.history.remove(msg_id)
1106 self.results.pop(msg_id, None)
1107 self.results.pop(msg_id, None)
1107 self.metadata.pop(msg_id, None)
1108 self.metadata.pop(msg_id, None)
1108 content = dict(msg_ids = theids)
1109 content = dict(msg_ids = theids)
1109
1110
1110 self.session.send(self._query_socket, 'resubmit_request', content)
1111 self.session.send(self._query_socket, 'resubmit_request', content)
1111
1112
1112 zmq.select([self._query_socket], [], [])
1113 zmq.select([self._query_socket], [], [])
1113 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1114 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1114 if self.debug:
1115 if self.debug:
1115 pprint(msg)
1116 pprint(msg)
1116 content = msg['content']
1117 content = msg['content']
1117 if content['status'] != 'ok':
1118 if content['status'] != 'ok':
1118 raise self._unwrap_exception(content)
1119 raise self._unwrap_exception(content)
1119
1120
1120 ar = AsyncHubResult(self, msg_ids=theids)
1121 ar = AsyncHubResult(self, msg_ids=theids)
1121
1122
1122 if block:
1123 if block:
1123 ar.wait()
1124 ar.wait()
1124
1125
1125 return ar
1126 return ar
1126
1127
1127 @spin_first
1128 @spin_first
1128 def result_status(self, msg_ids, status_only=True):
1129 def result_status(self, msg_ids, status_only=True):
1129 """Check on the status of the result(s) of the apply request with `msg_ids`.
1130 """Check on the status of the result(s) of the apply request with `msg_ids`.
1130
1131
1131 If status_only is False, then the actual results will be retrieved, else
1132 If status_only is False, then the actual results will be retrieved, else
1132 only the status of the results will be checked.
1133 only the status of the results will be checked.
1133
1134
1134 Parameters
1135 Parameters
1135 ----------
1136 ----------
1136
1137
1137 msg_ids : list of msg_ids
1138 msg_ids : list of msg_ids
1138 if int:
1139 if int:
1139 Passed as index to self.history for convenience.
1140 Passed as index to self.history for convenience.
1140 status_only : bool (default: True)
1141 status_only : bool (default: True)
1141 if False:
1142 if False:
1142 Retrieve the actual results of completed tasks.
1143 Retrieve the actual results of completed tasks.
1143
1144
1144 Returns
1145 Returns
1145 -------
1146 -------
1146
1147
1147 results : dict
1148 results : dict
1148 There will always be the keys 'pending' and 'completed', which will
1149 There will always be the keys 'pending' and 'completed', which will
1149 be lists of msg_ids that are incomplete or complete. If `status_only`
1150 be lists of msg_ids that are incomplete or complete. If `status_only`
1150 is False, then completed results will be keyed by their `msg_id`.
1151 is False, then completed results will be keyed by their `msg_id`.
1151 """
1152 """
1152 if not isinstance(msg_ids, (list,tuple)):
1153 if not isinstance(msg_ids, (list,tuple)):
1153 msg_ids = [msg_ids]
1154 msg_ids = [msg_ids]
1154
1155
1155 theids = []
1156 theids = []
1156 for msg_id in msg_ids:
1157 for msg_id in msg_ids:
1157 if isinstance(msg_id, int):
1158 if isinstance(msg_id, int):
1158 msg_id = self.history[msg_id]
1159 msg_id = self.history[msg_id]
1159 if not isinstance(msg_id, basestring):
1160 if not isinstance(msg_id, basestring):
1160 raise TypeError("msg_ids must be str, not %r"%msg_id)
1161 raise TypeError("msg_ids must be str, not %r"%msg_id)
1161 theids.append(msg_id)
1162 theids.append(msg_id)
1162
1163
1163 completed = []
1164 completed = []
1164 local_results = {}
1165 local_results = {}
1165
1166
1166 # comment this block out to temporarily disable local shortcut:
1167 # comment this block out to temporarily disable local shortcut:
1167 for msg_id in theids:
1168 for msg_id in theids:
1168 if msg_id in self.results:
1169 if msg_id in self.results:
1169 completed.append(msg_id)
1170 completed.append(msg_id)
1170 local_results[msg_id] = self.results[msg_id]
1171 local_results[msg_id] = self.results[msg_id]
1171 theids.remove(msg_id)
1172 theids.remove(msg_id)
1172
1173
1173 if theids: # some not locally cached
1174 if theids: # some not locally cached
1174 content = dict(msg_ids=theids, status_only=status_only)
1175 content = dict(msg_ids=theids, status_only=status_only)
1175 msg = self.session.send(self._query_socket, "result_request", content=content)
1176 msg = self.session.send(self._query_socket, "result_request", content=content)
1176 zmq.select([self._query_socket], [], [])
1177 zmq.select([self._query_socket], [], [])
1177 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1178 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1178 if self.debug:
1179 if self.debug:
1179 pprint(msg)
1180 pprint(msg)
1180 content = msg['content']
1181 content = msg['content']
1181 if content['status'] != 'ok':
1182 if content['status'] != 'ok':
1182 raise self._unwrap_exception(content)
1183 raise self._unwrap_exception(content)
1183 buffers = msg['buffers']
1184 buffers = msg['buffers']
1184 else:
1185 else:
1185 content = dict(completed=[],pending=[])
1186 content = dict(completed=[],pending=[])
1186
1187
1187 content['completed'].extend(completed)
1188 content['completed'].extend(completed)
1188
1189
1189 if status_only:
1190 if status_only:
1190 return content
1191 return content
1191
1192
1192 failures = []
1193 failures = []
1193 # load cached results into result:
1194 # load cached results into result:
1194 content.update(local_results)
1195 content.update(local_results)
1195
1196
1196 # update cache with results:
1197 # update cache with results:
1197 for msg_id in sorted(theids):
1198 for msg_id in sorted(theids):
1198 if msg_id in content['completed']:
1199 if msg_id in content['completed']:
1199 rec = content[msg_id]
1200 rec = content[msg_id]
1200 parent = rec['header']
1201 parent = rec['header']
1201 header = rec['result_header']
1202 header = rec['result_header']
1202 rcontent = rec['result_content']
1203 rcontent = rec['result_content']
1203 iodict = rec['io']
1204 iodict = rec['io']
1204 if isinstance(rcontent, str):
1205 if isinstance(rcontent, str):
1205 rcontent = self.session.unpack(rcontent)
1206 rcontent = self.session.unpack(rcontent)
1206
1207
1207 md = self.metadata[msg_id]
1208 md = self.metadata[msg_id]
1208 md.update(self._extract_metadata(header, parent, rcontent))
1209 md.update(self._extract_metadata(header, parent, rcontent))
1209 md.update(iodict)
1210 md.update(iodict)
1210
1211
1211 if rcontent['status'] == 'ok':
1212 if rcontent['status'] == 'ok':
1212 res,buffers = util.unserialize_object(buffers)
1213 res,buffers = util.unserialize_object(buffers)
1213 else:
1214 else:
1214 print rcontent
1215 print rcontent
1215 res = self._unwrap_exception(rcontent)
1216 res = self._unwrap_exception(rcontent)
1216 failures.append(res)
1217 failures.append(res)
1217
1218
1218 self.results[msg_id] = res
1219 self.results[msg_id] = res
1219 content[msg_id] = res
1220 content[msg_id] = res
1220
1221
1221 if len(theids) == 1 and failures:
1222 if len(theids) == 1 and failures:
1222 raise failures[0]
1223 raise failures[0]
1223
1224
1224 error.collect_exceptions(failures, "result_status")
1225 error.collect_exceptions(failures, "result_status")
1225 return content
1226 return content
1226
1227
1227 @spin_first
1228 @spin_first
1228 def queue_status(self, targets='all', verbose=False):
1229 def queue_status(self, targets='all', verbose=False):
1229 """Fetch the status of engine queues.
1230 """Fetch the status of engine queues.
1230
1231
1231 Parameters
1232 Parameters
1232 ----------
1233 ----------
1233
1234
1234 targets : int/str/list of ints/strs
1235 targets : int/str/list of ints/strs
1235 the engines whose states are to be queried.
1236 the engines whose states are to be queried.
1236 default : all
1237 default : all
1237 verbose : bool
1238 verbose : bool
1238 Whether to return lengths only, or lists of ids for each element
1239 Whether to return lengths only, or lists of ids for each element
1239 """
1240 """
1240 engine_ids = self._build_targets(targets)[1]
1241 engine_ids = self._build_targets(targets)[1]
1241 content = dict(targets=engine_ids, verbose=verbose)
1242 content = dict(targets=engine_ids, verbose=verbose)
1242 self.session.send(self._query_socket, "queue_request", content=content)
1243 self.session.send(self._query_socket, "queue_request", content=content)
1243 idents,msg = self.session.recv(self._query_socket, 0)
1244 idents,msg = self.session.recv(self._query_socket, 0)
1244 if self.debug:
1245 if self.debug:
1245 pprint(msg)
1246 pprint(msg)
1246 content = msg['content']
1247 content = msg['content']
1247 status = content.pop('status')
1248 status = content.pop('status')
1248 if status != 'ok':
1249 if status != 'ok':
1249 raise self._unwrap_exception(content)
1250 raise self._unwrap_exception(content)
1250 content = util.rekey(content)
1251 content = rekey(content)
1251 if isinstance(targets, int):
1252 if isinstance(targets, int):
1252 return content[targets]
1253 return content[targets]
1253 else:
1254 else:
1254 return content
1255 return content
1255
1256
1256 @spin_first
1257 @spin_first
1257 def purge_results(self, jobs=[], targets=[]):
1258 def purge_results(self, jobs=[], targets=[]):
1258 """Tell the Hub to forget results.
1259 """Tell the Hub to forget results.
1259
1260
1260 Individual results can be purged by msg_id, or the entire
1261 Individual results can be purged by msg_id, or the entire
1261 history of specific targets can be purged.
1262 history of specific targets can be purged.
1262
1263
1263 Parameters
1264 Parameters
1264 ----------
1265 ----------
1265
1266
1266 jobs : str or list of str or AsyncResult objects
1267 jobs : str or list of str or AsyncResult objects
1267 the msg_ids whose results should be forgotten.
1268 the msg_ids whose results should be forgotten.
1268 targets : int/str/list of ints/strs
1269 targets : int/str/list of ints/strs
1269 The targets, by uuid or int_id, whose entire history is to be purged.
1270 The targets, by uuid or int_id, whose entire history is to be purged.
1270 Use `targets='all'` to scrub everything from the Hub's memory.
1271 Use `targets='all'` to scrub everything from the Hub's memory.
1271
1272
1272 default : None
1273 default : None
1273 """
1274 """
1274 if not targets and not jobs:
1275 if not targets and not jobs:
1275 raise ValueError("Must specify at least one of `targets` and `jobs`")
1276 raise ValueError("Must specify at least one of `targets` and `jobs`")
1276 if targets:
1277 if targets:
1277 targets = self._build_targets(targets)[1]
1278 targets = self._build_targets(targets)[1]
1278
1279
1279 # construct msg_ids from jobs
1280 # construct msg_ids from jobs
1280 msg_ids = []
1281 msg_ids = []
1281 if isinstance(jobs, (basestring,AsyncResult)):
1282 if isinstance(jobs, (basestring,AsyncResult)):
1282 jobs = [jobs]
1283 jobs = [jobs]
1283 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1284 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1284 if bad_ids:
1285 if bad_ids:
1285 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1286 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1286 for j in jobs:
1287 for j in jobs:
1287 if isinstance(j, AsyncResult):
1288 if isinstance(j, AsyncResult):
1288 msg_ids.extend(j.msg_ids)
1289 msg_ids.extend(j.msg_ids)
1289 else:
1290 else:
1290 msg_ids.append(j)
1291 msg_ids.append(j)
1291
1292
1292 content = dict(targets=targets, msg_ids=msg_ids)
1293 content = dict(targets=targets, msg_ids=msg_ids)
1293 self.session.send(self._query_socket, "purge_request", content=content)
1294 self.session.send(self._query_socket, "purge_request", content=content)
1294 idents, msg = self.session.recv(self._query_socket, 0)
1295 idents, msg = self.session.recv(self._query_socket, 0)
1295 if self.debug:
1296 if self.debug:
1296 pprint(msg)
1297 pprint(msg)
1297 content = msg['content']
1298 content = msg['content']
1298 if content['status'] != 'ok':
1299 if content['status'] != 'ok':
1299 raise self._unwrap_exception(content)
1300 raise self._unwrap_exception(content)
1300
1301
1301 @spin_first
1302 @spin_first
1302 def hub_history(self):
1303 def hub_history(self):
1303 """Get the Hub's history
1304 """Get the Hub's history
1304
1305
1305 Just like the Client, the Hub has a history, which is a list of msg_ids.
1306 Just like the Client, the Hub has a history, which is a list of msg_ids.
1306 This will contain the history of all clients, and, depending on configuration,
1307 This will contain the history of all clients, and, depending on configuration,
1307 may contain history across multiple cluster sessions.
1308 may contain history across multiple cluster sessions.
1308
1309
1309 Any msg_id returned here is a valid argument to `get_result`.
1310 Any msg_id returned here is a valid argument to `get_result`.
1310
1311
1311 Returns
1312 Returns
1312 -------
1313 -------
1313
1314
1314 msg_ids : list of strs
1315 msg_ids : list of strs
1315 list of all msg_ids, ordered by task submission time.
1316 list of all msg_ids, ordered by task submission time.
1316 """
1317 """
1317
1318
1318 self.session.send(self._query_socket, "history_request", content={})
1319 self.session.send(self._query_socket, "history_request", content={})
1319 idents, msg = self.session.recv(self._query_socket, 0)
1320 idents, msg = self.session.recv(self._query_socket, 0)
1320
1321
1321 if self.debug:
1322 if self.debug:
1322 pprint(msg)
1323 pprint(msg)
1323 content = msg['content']
1324 content = msg['content']
1324 if content['status'] != 'ok':
1325 if content['status'] != 'ok':
1325 raise self._unwrap_exception(content)
1326 raise self._unwrap_exception(content)
1326 else:
1327 else:
1327 return content['history']
1328 return content['history']
1328
1329
1329 @spin_first
1330 @spin_first
1330 def db_query(self, query, keys=None):
1331 def db_query(self, query, keys=None):
1331 """Query the Hub's TaskRecord database
1332 """Query the Hub's TaskRecord database
1332
1333
1333 This will return a list of task record dicts that match `query`
1334 This will return a list of task record dicts that match `query`
1334
1335
1335 Parameters
1336 Parameters
1336 ----------
1337 ----------
1337
1338
1338 query : mongodb query dict
1339 query : mongodb query dict
1339 The search dict. See mongodb query docs for details.
1340 The search dict. See mongodb query docs for details.
1340 keys : list of strs [optional]
1341 keys : list of strs [optional]
1341 The subset of keys to be returned. The default is to fetch everything but buffers.
1342 The subset of keys to be returned. The default is to fetch everything but buffers.
1342 'msg_id' will *always* be included.
1343 'msg_id' will *always* be included.
1343 """
1344 """
1344 if isinstance(keys, basestring):
1345 if isinstance(keys, basestring):
1345 keys = [keys]
1346 keys = [keys]
1346 content = dict(query=query, keys=keys)
1347 content = dict(query=query, keys=keys)
1347 self.session.send(self._query_socket, "db_request", content=content)
1348 self.session.send(self._query_socket, "db_request", content=content)
1348 idents, msg = self.session.recv(self._query_socket, 0)
1349 idents, msg = self.session.recv(self._query_socket, 0)
1349 if self.debug:
1350 if self.debug:
1350 pprint(msg)
1351 pprint(msg)
1351 content = msg['content']
1352 content = msg['content']
1352 if content['status'] != 'ok':
1353 if content['status'] != 'ok':
1353 raise self._unwrap_exception(content)
1354 raise self._unwrap_exception(content)
1354
1355
1355 records = content['records']
1356 records = content['records']
1356
1357
1357 buffer_lens = content['buffer_lens']
1358 buffer_lens = content['buffer_lens']
1358 result_buffer_lens = content['result_buffer_lens']
1359 result_buffer_lens = content['result_buffer_lens']
1359 buffers = msg['buffers']
1360 buffers = msg['buffers']
1360 has_bufs = buffer_lens is not None
1361 has_bufs = buffer_lens is not None
1361 has_rbufs = result_buffer_lens is not None
1362 has_rbufs = result_buffer_lens is not None
1362 for i,rec in enumerate(records):
1363 for i,rec in enumerate(records):
1363 # relink buffers
1364 # relink buffers
1364 if has_bufs:
1365 if has_bufs:
1365 blen = buffer_lens[i]
1366 blen = buffer_lens[i]
1366 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1367 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1367 if has_rbufs:
1368 if has_rbufs:
1368 blen = result_buffer_lens[i]
1369 blen = result_buffer_lens[i]
1369 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1370 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1370
1371
1371 return records
1372 return records
1372
1373
1373 __all__ = [ 'Client' ]
1374 __all__ = [ 'Client' ]
@@ -1,473 +1,450
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 # IPython imports
42 # IPython imports
43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 from IPython.utils.newserialized import serialize, unserialize
44 from IPython.utils.newserialized import serialize, unserialize
45 from IPython.zmq.log import EnginePUBHandler
45 from IPython.zmq.log import EnginePUBHandler
46
46
47 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
48 # Classes
48 # Classes
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50
50
51 class Namespace(dict):
51 class Namespace(dict):
52 """Subclass of dict for attribute access to keys."""
52 """Subclass of dict for attribute access to keys."""
53
53
54 def __getattr__(self, key):
54 def __getattr__(self, key):
55 """getattr aliased to getitem"""
55 """getattr aliased to getitem"""
56 if key in self.iterkeys():
56 if key in self.iterkeys():
57 return self[key]
57 return self[key]
58 else:
58 else:
59 raise NameError(key)
59 raise NameError(key)
60
60
61 def __setattr__(self, key, value):
61 def __setattr__(self, key, value):
62 """setattr aliased to setitem, with strict"""
62 """setattr aliased to setitem, with strict"""
63 if hasattr(dict, key):
63 if hasattr(dict, key):
64 raise KeyError("Cannot override dict keys %r"%key)
64 raise KeyError("Cannot override dict keys %r"%key)
65 self[key] = value
65 self[key] = value
66
66
67
67
68 class ReverseDict(dict):
68 class ReverseDict(dict):
69 """simple double-keyed subset of dict methods."""
69 """simple double-keyed subset of dict methods."""
70
70
71 def __init__(self, *args, **kwargs):
71 def __init__(self, *args, **kwargs):
72 dict.__init__(self, *args, **kwargs)
72 dict.__init__(self, *args, **kwargs)
73 self._reverse = dict()
73 self._reverse = dict()
74 for key, value in self.iteritems():
74 for key, value in self.iteritems():
75 self._reverse[value] = key
75 self._reverse[value] = key
76
76
77 def __getitem__(self, key):
77 def __getitem__(self, key):
78 try:
78 try:
79 return dict.__getitem__(self, key)
79 return dict.__getitem__(self, key)
80 except KeyError:
80 except KeyError:
81 return self._reverse[key]
81 return self._reverse[key]
82
82
83 def __setitem__(self, key, value):
83 def __setitem__(self, key, value):
84 if key in self._reverse:
84 if key in self._reverse:
85 raise KeyError("Can't have key %r on both sides!"%key)
85 raise KeyError("Can't have key %r on both sides!"%key)
86 dict.__setitem__(self, key, value)
86 dict.__setitem__(self, key, value)
87 self._reverse[value] = key
87 self._reverse[value] = key
88
88
89 def pop(self, key):
89 def pop(self, key):
90 value = dict.pop(self, key)
90 value = dict.pop(self, key)
91 self._reverse.pop(value)
91 self._reverse.pop(value)
92 return value
92 return value
93
93
94 def get(self, key, default=None):
94 def get(self, key, default=None):
95 try:
95 try:
96 return self[key]
96 return self[key]
97 except KeyError:
97 except KeyError:
98 return default
98 return default
99
99
100 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
101 # Functions
101 # Functions
102 #-----------------------------------------------------------------------------
102 #-----------------------------------------------------------------------------
103
103
104 def validate_url(url):
104 def validate_url(url):
105 """validate a url for zeromq"""
105 """validate a url for zeromq"""
106 if not isinstance(url, basestring):
106 if not isinstance(url, basestring):
107 raise TypeError("url must be a string, not %r"%type(url))
107 raise TypeError("url must be a string, not %r"%type(url))
108 url = url.lower()
108 url = url.lower()
109
109
110 proto_addr = url.split('://')
110 proto_addr = url.split('://')
111 assert len(proto_addr) == 2, 'Invalid url: %r'%url
111 assert len(proto_addr) == 2, 'Invalid url: %r'%url
112 proto, addr = proto_addr
112 proto, addr = proto_addr
113 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
113 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
114
114
115 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
115 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
116 # author: Remi Sabourin
116 # author: Remi Sabourin
117 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
117 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
118
118
119 if proto == 'tcp':
119 if proto == 'tcp':
120 lis = addr.split(':')
120 lis = addr.split(':')
121 assert len(lis) == 2, 'Invalid url: %r'%url
121 assert len(lis) == 2, 'Invalid url: %r'%url
122 addr,s_port = lis
122 addr,s_port = lis
123 try:
123 try:
124 port = int(s_port)
124 port = int(s_port)
125 except ValueError:
125 except ValueError:
126 raise AssertionError("Invalid port %r in url: %r"%(port, url))
126 raise AssertionError("Invalid port %r in url: %r"%(port, url))
127
127
128 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
128 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
129
129
130 else:
130 else:
131 # only validate tcp urls currently
131 # only validate tcp urls currently
132 pass
132 pass
133
133
134 return True
134 return True
135
135
136
136
137 def validate_url_container(container):
137 def validate_url_container(container):
138 """validate a potentially nested collection of urls."""
138 """validate a potentially nested collection of urls."""
139 if isinstance(container, basestring):
139 if isinstance(container, basestring):
140 url = container
140 url = container
141 return validate_url(url)
141 return validate_url(url)
142 elif isinstance(container, dict):
142 elif isinstance(container, dict):
143 container = container.itervalues()
143 container = container.itervalues()
144
144
145 for element in container:
145 for element in container:
146 validate_url_container(element)
146 validate_url_container(element)
147
147
148
148
149 def split_url(url):
149 def split_url(url):
150 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
150 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
151 proto_addr = url.split('://')
151 proto_addr = url.split('://')
152 assert len(proto_addr) == 2, 'Invalid url: %r'%url
152 assert len(proto_addr) == 2, 'Invalid url: %r'%url
153 proto, addr = proto_addr
153 proto, addr = proto_addr
154 lis = addr.split(':')
154 lis = addr.split(':')
155 assert len(lis) == 2, 'Invalid url: %r'%url
155 assert len(lis) == 2, 'Invalid url: %r'%url
156 addr,s_port = lis
156 addr,s_port = lis
157 return proto,addr,s_port
157 return proto,addr,s_port
158
158
159 def disambiguate_ip_address(ip, location=None):
159 def disambiguate_ip_address(ip, location=None):
160 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
160 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
161 ones, based on the location (default interpretation of location is localhost)."""
161 ones, based on the location (default interpretation of location is localhost)."""
162 if ip in ('0.0.0.0', '*'):
162 if ip in ('0.0.0.0', '*'):
163 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
163 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
164 if location is None or location in external_ips:
164 if location is None or location in external_ips:
165 ip='127.0.0.1'
165 ip='127.0.0.1'
166 elif location:
166 elif location:
167 return location
167 return location
168 return ip
168 return ip
169
169
170 def disambiguate_url(url, location=None):
170 def disambiguate_url(url, location=None):
171 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
171 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
172 ones, based on the location (default interpretation is localhost).
172 ones, based on the location (default interpretation is localhost).
173
173
174 This is for zeromq urls, such as tcp://*:10101."""
174 This is for zeromq urls, such as tcp://*:10101."""
175 try:
175 try:
176 proto,ip,port = split_url(url)
176 proto,ip,port = split_url(url)
177 except AssertionError:
177 except AssertionError:
178 # probably not tcp url; could be ipc, etc.
178 # probably not tcp url; could be ipc, etc.
179 return url
179 return url
180
180
181 ip = disambiguate_ip_address(ip,location)
181 ip = disambiguate_ip_address(ip,location)
182
182
183 return "%s://%s:%s"%(proto,ip,port)
183 return "%s://%s:%s"%(proto,ip,port)
184
184
185
186 def rekey(dikt):
187 """Rekey a dict that has been forced to use str keys where there should be
188 ints by json. This belongs in the jsonutil added by fperez."""
189 for k in dikt.iterkeys():
190 if isinstance(k, str):
191 ik=fk=None
192 try:
193 ik = int(k)
194 except ValueError:
195 try:
196 fk = float(k)
197 except ValueError:
198 continue
199 if ik is not None:
200 nk = ik
201 else:
202 nk = fk
203 if nk in dikt:
204 raise KeyError("already have key %r"%nk)
205 dikt[nk] = dikt.pop(k)
206 return dikt
207
208 def serialize_object(obj, threshold=64e-6):
185 def serialize_object(obj, threshold=64e-6):
209 """Serialize an object into a list of sendable buffers.
186 """Serialize an object into a list of sendable buffers.
210
187
211 Parameters
188 Parameters
212 ----------
189 ----------
213
190
214 obj : object
191 obj : object
215 The object to be serialized
192 The object to be serialized
216 threshold : float
193 threshold : float
217 The threshold for not double-pickling the content.
194 The threshold for not double-pickling the content.
218
195
219
196
220 Returns
197 Returns
221 -------
198 -------
222 ('pmd', [bufs]) :
199 ('pmd', [bufs]) :
223 where pmd is the pickled metadata wrapper,
200 where pmd is the pickled metadata wrapper,
224 bufs is a list of data buffers
201 bufs is a list of data buffers
225 """
202 """
226 databuffers = []
203 databuffers = []
227 if isinstance(obj, (list, tuple)):
204 if isinstance(obj, (list, tuple)):
228 clist = canSequence(obj)
205 clist = canSequence(obj)
229 slist = map(serialize, clist)
206 slist = map(serialize, clist)
230 for s in slist:
207 for s in slist:
231 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
208 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
232 databuffers.append(s.getData())
209 databuffers.append(s.getData())
233 s.data = None
210 s.data = None
234 return pickle.dumps(slist,-1), databuffers
211 return pickle.dumps(slist,-1), databuffers
235 elif isinstance(obj, dict):
212 elif isinstance(obj, dict):
236 sobj = {}
213 sobj = {}
237 for k in sorted(obj.iterkeys()):
214 for k in sorted(obj.iterkeys()):
238 s = serialize(can(obj[k]))
215 s = serialize(can(obj[k]))
239 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
216 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
240 databuffers.append(s.getData())
217 databuffers.append(s.getData())
241 s.data = None
218 s.data = None
242 sobj[k] = s
219 sobj[k] = s
243 return pickle.dumps(sobj,-1),databuffers
220 return pickle.dumps(sobj,-1),databuffers
244 else:
221 else:
245 s = serialize(can(obj))
222 s = serialize(can(obj))
246 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
223 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
247 databuffers.append(s.getData())
224 databuffers.append(s.getData())
248 s.data = None
225 s.data = None
249 return pickle.dumps(s,-1),databuffers
226 return pickle.dumps(s,-1),databuffers
250
227
251
228
252 def unserialize_object(bufs):
229 def unserialize_object(bufs):
253 """reconstruct an object serialized by serialize_object from data buffers."""
230 """reconstruct an object serialized by serialize_object from data buffers."""
254 bufs = list(bufs)
231 bufs = list(bufs)
255 sobj = pickle.loads(bufs.pop(0))
232 sobj = pickle.loads(bufs.pop(0))
256 if isinstance(sobj, (list, tuple)):
233 if isinstance(sobj, (list, tuple)):
257 for s in sobj:
234 for s in sobj:
258 if s.data is None:
235 if s.data is None:
259 s.data = bufs.pop(0)
236 s.data = bufs.pop(0)
260 return uncanSequence(map(unserialize, sobj)), bufs
237 return uncanSequence(map(unserialize, sobj)), bufs
261 elif isinstance(sobj, dict):
238 elif isinstance(sobj, dict):
262 newobj = {}
239 newobj = {}
263 for k in sorted(sobj.iterkeys()):
240 for k in sorted(sobj.iterkeys()):
264 s = sobj[k]
241 s = sobj[k]
265 if s.data is None:
242 if s.data is None:
266 s.data = bufs.pop(0)
243 s.data = bufs.pop(0)
267 newobj[k] = uncan(unserialize(s))
244 newobj[k] = uncan(unserialize(s))
268 return newobj, bufs
245 return newobj, bufs
269 else:
246 else:
270 if sobj.data is None:
247 if sobj.data is None:
271 sobj.data = bufs.pop(0)
248 sobj.data = bufs.pop(0)
272 return uncan(unserialize(sobj)), bufs
249 return uncan(unserialize(sobj)), bufs
273
250
274 def pack_apply_message(f, args, kwargs, threshold=64e-6):
251 def pack_apply_message(f, args, kwargs, threshold=64e-6):
275 """pack up a function, args, and kwargs to be sent over the wire
252 """pack up a function, args, and kwargs to be sent over the wire
276 as a series of buffers. Any object whose data is larger than `threshold`
253 as a series of buffers. Any object whose data is larger than `threshold`
277 will not have their data copied (currently only numpy arrays support zero-copy)"""
254 will not have their data copied (currently only numpy arrays support zero-copy)"""
278 msg = [pickle.dumps(can(f),-1)]
255 msg = [pickle.dumps(can(f),-1)]
279 databuffers = [] # for large objects
256 databuffers = [] # for large objects
280 sargs, bufs = serialize_object(args,threshold)
257 sargs, bufs = serialize_object(args,threshold)
281 msg.append(sargs)
258 msg.append(sargs)
282 databuffers.extend(bufs)
259 databuffers.extend(bufs)
283 skwargs, bufs = serialize_object(kwargs,threshold)
260 skwargs, bufs = serialize_object(kwargs,threshold)
284 msg.append(skwargs)
261 msg.append(skwargs)
285 databuffers.extend(bufs)
262 databuffers.extend(bufs)
286 msg.extend(databuffers)
263 msg.extend(databuffers)
287 return msg
264 return msg
288
265
289 def unpack_apply_message(bufs, g=None, copy=True):
266 def unpack_apply_message(bufs, g=None, copy=True):
290 """unpack f,args,kwargs from buffers packed by pack_apply_message()
267 """unpack f,args,kwargs from buffers packed by pack_apply_message()
291 Returns: original f,args,kwargs"""
268 Returns: original f,args,kwargs"""
292 bufs = list(bufs) # allow us to pop
269 bufs = list(bufs) # allow us to pop
293 assert len(bufs) >= 3, "not enough buffers!"
270 assert len(bufs) >= 3, "not enough buffers!"
294 if not copy:
271 if not copy:
295 for i in range(3):
272 for i in range(3):
296 bufs[i] = bufs[i].bytes
273 bufs[i] = bufs[i].bytes
297 cf = pickle.loads(bufs.pop(0))
274 cf = pickle.loads(bufs.pop(0))
298 sargs = list(pickle.loads(bufs.pop(0)))
275 sargs = list(pickle.loads(bufs.pop(0)))
299 skwargs = dict(pickle.loads(bufs.pop(0)))
276 skwargs = dict(pickle.loads(bufs.pop(0)))
300 # print sargs, skwargs
277 # print sargs, skwargs
301 f = uncan(cf, g)
278 f = uncan(cf, g)
302 for sa in sargs:
279 for sa in sargs:
303 if sa.data is None:
280 if sa.data is None:
304 m = bufs.pop(0)
281 m = bufs.pop(0)
305 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
282 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
306 # always use a buffer, until memoryviews get sorted out
283 # always use a buffer, until memoryviews get sorted out
307 sa.data = buffer(m)
284 sa.data = buffer(m)
308 # disable memoryview support
285 # disable memoryview support
309 # if copy:
286 # if copy:
310 # sa.data = buffer(m)
287 # sa.data = buffer(m)
311 # else:
288 # else:
312 # sa.data = m.buffer
289 # sa.data = m.buffer
313 else:
290 else:
314 if copy:
291 if copy:
315 sa.data = m
292 sa.data = m
316 else:
293 else:
317 sa.data = m.bytes
294 sa.data = m.bytes
318
295
319 args = uncanSequence(map(unserialize, sargs), g)
296 args = uncanSequence(map(unserialize, sargs), g)
320 kwargs = {}
297 kwargs = {}
321 for k in sorted(skwargs.iterkeys()):
298 for k in sorted(skwargs.iterkeys()):
322 sa = skwargs[k]
299 sa = skwargs[k]
323 if sa.data is None:
300 if sa.data is None:
324 m = bufs.pop(0)
301 m = bufs.pop(0)
325 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
302 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
326 # always use a buffer, until memoryviews get sorted out
303 # always use a buffer, until memoryviews get sorted out
327 sa.data = buffer(m)
304 sa.data = buffer(m)
328 # disable memoryview support
305 # disable memoryview support
329 # if copy:
306 # if copy:
330 # sa.data = buffer(m)
307 # sa.data = buffer(m)
331 # else:
308 # else:
332 # sa.data = m.buffer
309 # sa.data = m.buffer
333 else:
310 else:
334 if copy:
311 if copy:
335 sa.data = m
312 sa.data = m
336 else:
313 else:
337 sa.data = m.bytes
314 sa.data = m.bytes
338
315
339 kwargs[k] = uncan(unserialize(sa), g)
316 kwargs[k] = uncan(unserialize(sa), g)
340
317
341 return f,args,kwargs
318 return f,args,kwargs
342
319
343 #--------------------------------------------------------------------------
320 #--------------------------------------------------------------------------
344 # helpers for implementing old MEC API via view.apply
321 # helpers for implementing old MEC API via view.apply
345 #--------------------------------------------------------------------------
322 #--------------------------------------------------------------------------
346
323
347 def interactive(f):
324 def interactive(f):
348 """decorator for making functions appear as interactively defined.
325 """decorator for making functions appear as interactively defined.
349 This results in the function being linked to the user_ns as globals()
326 This results in the function being linked to the user_ns as globals()
350 instead of the module globals().
327 instead of the module globals().
351 """
328 """
352 f.__module__ = '__main__'
329 f.__module__ = '__main__'
353 return f
330 return f
354
331
355 @interactive
332 @interactive
356 def _push(ns):
333 def _push(ns):
357 """helper method for implementing `client.push` via `client.apply`"""
334 """helper method for implementing `client.push` via `client.apply`"""
358 globals().update(ns)
335 globals().update(ns)
359
336
360 @interactive
337 @interactive
361 def _pull(keys):
338 def _pull(keys):
362 """helper method for implementing `client.pull` via `client.apply`"""
339 """helper method for implementing `client.pull` via `client.apply`"""
363 user_ns = globals()
340 user_ns = globals()
364 if isinstance(keys, (list,tuple, set)):
341 if isinstance(keys, (list,tuple, set)):
365 for key in keys:
342 for key in keys:
366 if not user_ns.has_key(key):
343 if not user_ns.has_key(key):
367 raise NameError("name '%s' is not defined"%key)
344 raise NameError("name '%s' is not defined"%key)
368 return map(user_ns.get, keys)
345 return map(user_ns.get, keys)
369 else:
346 else:
370 if not user_ns.has_key(keys):
347 if not user_ns.has_key(keys):
371 raise NameError("name '%s' is not defined"%keys)
348 raise NameError("name '%s' is not defined"%keys)
372 return user_ns.get(keys)
349 return user_ns.get(keys)
373
350
374 @interactive
351 @interactive
375 def _execute(code):
352 def _execute(code):
376 """helper method for implementing `client.execute` via `client.apply`"""
353 """helper method for implementing `client.execute` via `client.apply`"""
377 exec code in globals()
354 exec code in globals()
378
355
379 #--------------------------------------------------------------------------
356 #--------------------------------------------------------------------------
380 # extra process management utilities
357 # extra process management utilities
381 #--------------------------------------------------------------------------
358 #--------------------------------------------------------------------------
382
359
383 _random_ports = set()
360 _random_ports = set()
384
361
385 def select_random_ports(n):
362 def select_random_ports(n):
386 """Selects and return n random ports that are available."""
363 """Selects and return n random ports that are available."""
387 ports = []
364 ports = []
388 for i in xrange(n):
365 for i in xrange(n):
389 sock = socket.socket()
366 sock = socket.socket()
390 sock.bind(('', 0))
367 sock.bind(('', 0))
391 while sock.getsockname()[1] in _random_ports:
368 while sock.getsockname()[1] in _random_ports:
392 sock.close()
369 sock.close()
393 sock = socket.socket()
370 sock = socket.socket()
394 sock.bind(('', 0))
371 sock.bind(('', 0))
395 ports.append(sock)
372 ports.append(sock)
396 for i, sock in enumerate(ports):
373 for i, sock in enumerate(ports):
397 port = sock.getsockname()[1]
374 port = sock.getsockname()[1]
398 sock.close()
375 sock.close()
399 ports[i] = port
376 ports[i] = port
400 _random_ports.add(port)
377 _random_ports.add(port)
401 return ports
378 return ports
402
379
403 def signal_children(children):
380 def signal_children(children):
404 """Relay interupt/term signals to children, for more solid process cleanup."""
381 """Relay interupt/term signals to children, for more solid process cleanup."""
405 def terminate_children(sig, frame):
382 def terminate_children(sig, frame):
406 logging.critical("Got signal %i, terminating children..."%sig)
383 logging.critical("Got signal %i, terminating children..."%sig)
407 for child in children:
384 for child in children:
408 child.terminate()
385 child.terminate()
409
386
410 sys.exit(sig != SIGINT)
387 sys.exit(sig != SIGINT)
411 # sys.exit(sig)
388 # sys.exit(sig)
412 for sig in (SIGINT, SIGABRT, SIGTERM):
389 for sig in (SIGINT, SIGABRT, SIGTERM):
413 signal(sig, terminate_children)
390 signal(sig, terminate_children)
414
391
415 def generate_exec_key(keyfile):
392 def generate_exec_key(keyfile):
416 import uuid
393 import uuid
417 newkey = str(uuid.uuid4())
394 newkey = str(uuid.uuid4())
418 with open(keyfile, 'w') as f:
395 with open(keyfile, 'w') as f:
419 # f.write('ipython-key ')
396 # f.write('ipython-key ')
420 f.write(newkey+'\n')
397 f.write(newkey+'\n')
421 # set user-only RW permissions (0600)
398 # set user-only RW permissions (0600)
422 # this will have no effect on Windows
399 # this will have no effect on Windows
423 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
400 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
424
401
425
402
426 def integer_loglevel(loglevel):
403 def integer_loglevel(loglevel):
427 try:
404 try:
428 loglevel = int(loglevel)
405 loglevel = int(loglevel)
429 except ValueError:
406 except ValueError:
430 if isinstance(loglevel, str):
407 if isinstance(loglevel, str):
431 loglevel = getattr(logging, loglevel)
408 loglevel = getattr(logging, loglevel)
432 return loglevel
409 return loglevel
433
410
434 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
411 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
435 logger = logging.getLogger(logname)
412 logger = logging.getLogger(logname)
436 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
413 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
437 # don't add a second PUBHandler
414 # don't add a second PUBHandler
438 return
415 return
439 loglevel = integer_loglevel(loglevel)
416 loglevel = integer_loglevel(loglevel)
440 lsock = context.socket(zmq.PUB)
417 lsock = context.socket(zmq.PUB)
441 lsock.connect(iface)
418 lsock.connect(iface)
442 handler = handlers.PUBHandler(lsock)
419 handler = handlers.PUBHandler(lsock)
443 handler.setLevel(loglevel)
420 handler.setLevel(loglevel)
444 handler.root_topic = root
421 handler.root_topic = root
445 logger.addHandler(handler)
422 logger.addHandler(handler)
446 logger.setLevel(loglevel)
423 logger.setLevel(loglevel)
447
424
448 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
425 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
449 logger = logging.getLogger()
426 logger = logging.getLogger()
450 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
427 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
451 # don't add a second PUBHandler
428 # don't add a second PUBHandler
452 return
429 return
453 loglevel = integer_loglevel(loglevel)
430 loglevel = integer_loglevel(loglevel)
454 lsock = context.socket(zmq.PUB)
431 lsock = context.socket(zmq.PUB)
455 lsock.connect(iface)
432 lsock.connect(iface)
456 handler = EnginePUBHandler(engine, lsock)
433 handler = EnginePUBHandler(engine, lsock)
457 handler.setLevel(loglevel)
434 handler.setLevel(loglevel)
458 logger.addHandler(handler)
435 logger.addHandler(handler)
459 logger.setLevel(loglevel)
436 logger.setLevel(loglevel)
460 return logger
437 return logger
461
438
462 def local_logger(logname, loglevel=logging.DEBUG):
439 def local_logger(logname, loglevel=logging.DEBUG):
463 loglevel = integer_loglevel(loglevel)
440 loglevel = integer_loglevel(loglevel)
464 logger = logging.getLogger(logname)
441 logger = logging.getLogger(logname)
465 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
442 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
466 # don't add a second StreamHandler
443 # don't add a second StreamHandler
467 return
444 return
468 handler = logging.StreamHandler()
445 handler = logging.StreamHandler()
469 handler.setLevel(loglevel)
446 handler.setLevel(loglevel)
470 logger.addHandler(handler)
447 logger.addHandler(handler)
471 logger.setLevel(loglevel)
448 logger.setLevel(loglevel)
472 return logger
449 return logger
473
450
@@ -1,134 +1,157
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010 The IPython Development Team
4 # Copyright (C) 2010 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import re
14 import re
15 import types
15 import types
16 from datetime import datetime
16 from datetime import datetime
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Globals and constants
19 # Globals and constants
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 # timestamp formats
22 # timestamp formats
23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Classes and functions
27 # Classes and functions
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30 def rekey(dikt):
31 """Rekey a dict that has been forced to use str keys where there should be
32 ints by json."""
33 for k in dikt.iterkeys():
34 if isinstance(k, basestring):
35 ik=fk=None
36 try:
37 ik = int(k)
38 except ValueError:
39 try:
40 fk = float(k)
41 except ValueError:
42 continue
43 if ik is not None:
44 nk = ik
45 else:
46 nk = fk
47 if nk in dikt:
48 raise KeyError("already have key %r"%nk)
49 dikt[nk] = dikt.pop(k)
50 return dikt
51
52
30 def extract_dates(obj):
53 def extract_dates(obj):
31 """extract ISO8601 dates from unpacked JSON"""
54 """extract ISO8601 dates from unpacked JSON"""
32 if isinstance(obj, dict):
55 if isinstance(obj, dict):
33 obj = dict(obj) # don't clobber
56 obj = dict(obj) # don't clobber
34 for k,v in obj.iteritems():
57 for k,v in obj.iteritems():
35 obj[k] = extract_dates(v)
58 obj[k] = extract_dates(v)
36 elif isinstance(obj, (list, tuple)):
59 elif isinstance(obj, (list, tuple)):
37 obj = [ extract_dates(o) for o in obj ]
60 obj = [ extract_dates(o) for o in obj ]
38 elif isinstance(obj, basestring):
61 elif isinstance(obj, basestring):
39 if ISO8601_PAT.match(obj):
62 if ISO8601_PAT.match(obj):
40 obj = datetime.strptime(obj, ISO8601)
63 obj = datetime.strptime(obj, ISO8601)
41 return obj
64 return obj
42
65
43 def squash_dates(obj):
66 def squash_dates(obj):
44 """squash datetime objects into ISO8601 strings"""
67 """squash datetime objects into ISO8601 strings"""
45 if isinstance(obj, dict):
68 if isinstance(obj, dict):
46 obj = dict(obj) # don't clobber
69 obj = dict(obj) # don't clobber
47 for k,v in obj.iteritems():
70 for k,v in obj.iteritems():
48 obj[k] = squash_dates(v)
71 obj[k] = squash_dates(v)
49 elif isinstance(obj, (list, tuple)):
72 elif isinstance(obj, (list, tuple)):
50 obj = [ squash_dates(o) for o in obj ]
73 obj = [ squash_dates(o) for o in obj ]
51 elif isinstance(obj, datetime):
74 elif isinstance(obj, datetime):
52 obj = obj.strftime(ISO8601)
75 obj = obj.strftime(ISO8601)
53 return obj
76 return obj
54
77
55 def date_default(obj):
78 def date_default(obj):
56 """default function for packing datetime objects in JSON."""
79 """default function for packing datetime objects in JSON."""
57 if isinstance(obj, datetime):
80 if isinstance(obj, datetime):
58 return obj.strftime(ISO8601)
81 return obj.strftime(ISO8601)
59 else:
82 else:
60 raise TypeError("%r is not JSON serializable"%obj)
83 raise TypeError("%r is not JSON serializable"%obj)
61
84
62
85
63
86
64 def json_clean(obj):
87 def json_clean(obj):
65 """Clean an object to ensure it's safe to encode in JSON.
88 """Clean an object to ensure it's safe to encode in JSON.
66
89
67 Atomic, immutable objects are returned unmodified. Sets and tuples are
90 Atomic, immutable objects are returned unmodified. Sets and tuples are
68 converted to lists, lists are copied and dicts are also copied.
91 converted to lists, lists are copied and dicts are also copied.
69
92
70 Note: dicts whose keys could cause collisions upon encoding (such as a dict
93 Note: dicts whose keys could cause collisions upon encoding (such as a dict
71 with both the number 1 and the string '1' as keys) will cause a ValueError
94 with both the number 1 and the string '1' as keys) will cause a ValueError
72 to be raised.
95 to be raised.
73
96
74 Parameters
97 Parameters
75 ----------
98 ----------
76 obj : any python object
99 obj : any python object
77
100
78 Returns
101 Returns
79 -------
102 -------
80 out : object
103 out : object
81
104
82 A version of the input which will not cause an encoding error when
105 A version of the input which will not cause an encoding error when
83 encoded as JSON. Note that this function does not *encode* its inputs,
106 encoded as JSON. Note that this function does not *encode* its inputs,
84 it simply sanitizes it so that there will be no encoding errors later.
107 it simply sanitizes it so that there will be no encoding errors later.
85
108
86 Examples
109 Examples
87 --------
110 --------
88 >>> json_clean(4)
111 >>> json_clean(4)
89 4
112 4
90 >>> json_clean(range(10))
113 >>> json_clean(range(10))
91 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
114 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
92 >>> json_clean(dict(x=1, y=2))
115 >>> json_clean(dict(x=1, y=2))
93 {'y': 2, 'x': 1}
116 {'y': 2, 'x': 1}
94 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
117 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
95 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
118 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
96 >>> json_clean(True)
119 >>> json_clean(True)
97 True
120 True
98 """
121 """
99 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
122 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
100 # listed explicitly because bools pass as int instances
123 # listed explicitly because bools pass as int instances
101 atomic_ok = (basestring, int, float, types.NoneType)
124 atomic_ok = (basestring, int, float, types.NoneType)
102
125
103 # containers that we need to convert into lists
126 # containers that we need to convert into lists
104 container_to_list = (tuple, set, types.GeneratorType)
127 container_to_list = (tuple, set, types.GeneratorType)
105
128
106 if isinstance(obj, atomic_ok):
129 if isinstance(obj, atomic_ok):
107 return obj
130 return obj
108
131
109 if isinstance(obj, container_to_list) or (
132 if isinstance(obj, container_to_list) or (
110 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
133 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
111 obj = list(obj)
134 obj = list(obj)
112
135
113 if isinstance(obj, list):
136 if isinstance(obj, list):
114 return [json_clean(x) for x in obj]
137 return [json_clean(x) for x in obj]
115
138
116 if isinstance(obj, dict):
139 if isinstance(obj, dict):
117 # First, validate that the dict won't lose data in conversion due to
140 # First, validate that the dict won't lose data in conversion due to
118 # key collisions after stringification. This can happen with keys like
141 # key collisions after stringification. This can happen with keys like
119 # True and 'true' or 1 and '1', which collide in JSON.
142 # True and 'true' or 1 and '1', which collide in JSON.
120 nkeys = len(obj)
143 nkeys = len(obj)
121 nkeys_collapsed = len(set(map(str, obj)))
144 nkeys_collapsed = len(set(map(str, obj)))
122 if nkeys != nkeys_collapsed:
145 if nkeys != nkeys_collapsed:
123 raise ValueError('dict can not be safely converted to JSON: '
146 raise ValueError('dict can not be safely converted to JSON: '
124 'key collision would lead to dropped values')
147 'key collision would lead to dropped values')
125 # If all OK, proceed by making the new dict that will be json-safe
148 # If all OK, proceed by making the new dict that will be json-safe
126 out = {}
149 out = {}
127 for k,v in obj.iteritems():
150 for k,v in obj.iteritems():
128 out[str(k)] = json_clean(v)
151 out[str(k)] = json_clean(v)
129 return out
152 return out
130
153
131 # If we get here, we don't know how to handle the object, so we just get
154 # If we get here, we don't know how to handle the object, so we just get
132 # its repr and return that. This will catch lambdas, open sockets, class
155 # its repr and return that. This will catch lambdas, open sockets, class
133 # objects, and any other complicated contraption that json can't encode
156 # objects, and any other complicated contraption that json can't encode
134 return repr(obj)
157 return repr(obj)
General Comments 0
You need to be logged in to leave comments. Login now