##// END OF EJS Templates
adapt kernel/error.py to zmq, improve error propagation.
MinRK -
Show More
@@ -0,0 +1,276
1 # encoding: utf-8
2
3 """Classes and functions for kernel related errors and exceptions."""
4 from __future__ import print_function
5
6 __docformat__ = "restructuredtext en"
7
8 # Tell nose to skip this module
9 __test__ = {}
10
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
17
18 #-------------------------------------------------------------------------------
19 # Error classes
20 #-------------------------------------------------------------------------------
21 class IPythonError(Exception):
22 """Base exception that all of our exceptions inherit from.
23
24 This can be raised by code that doesn't have any more specific
25 information."""
26
27 pass
28
29 # Exceptions associated with the controller objects
30 class ControllerError(IPythonError): pass
31
32 class ControllerCreationError(ControllerError): pass
33
34
35 # Exceptions associated with the Engines
36 class EngineError(IPythonError): pass
37
38 class EngineCreationError(EngineError): pass
39
40 class KernelError(IPythonError):
41 pass
42
43 class NotDefined(KernelError):
44 def __init__(self, name):
45 self.name = name
46 self.args = (name,)
47
48 def __repr__(self):
49 return '<NotDefined: %s>' % self.name
50
51 __str__ = __repr__
52
53
54 class QueueCleared(KernelError):
55 pass
56
57
58 class IdInUse(KernelError):
59 pass
60
61
62 class ProtocolError(KernelError):
63 pass
64
65
66 class ConnectionError(KernelError):
67 pass
68
69
70 class InvalidEngineID(KernelError):
71 pass
72
73
74 class NoEnginesRegistered(KernelError):
75 pass
76
77
78 class InvalidClientID(KernelError):
79 pass
80
81
82 class InvalidDeferredID(KernelError):
83 pass
84
85
86 class SerializationError(KernelError):
87 pass
88
89
90 class MessageSizeError(KernelError):
91 pass
92
93
94 class PBMessageSizeError(MessageSizeError):
95 pass
96
97
98 class ResultNotCompleted(KernelError):
99 pass
100
101
102 class ResultAlreadyRetrieved(KernelError):
103 pass
104
105 class ClientError(KernelError):
106 pass
107
108
109 class TaskAborted(KernelError):
110 pass
111
112
113 class TaskTimeout(KernelError):
114 pass
115
116
117 class NotAPendingResult(KernelError):
118 pass
119
120
121 class UnpickleableException(KernelError):
122 pass
123
124
125 class AbortedPendingDeferredError(KernelError):
126 pass
127
128
129 class InvalidProperty(KernelError):
130 pass
131
132
133 class MissingBlockArgument(KernelError):
134 pass
135
136
137 class StopLocalExecution(KernelError):
138 pass
139
140
141 class SecurityError(KernelError):
142 pass
143
144
145 class FileTimeoutError(KernelError):
146 pass
147
148 class RemoteError(KernelError):
149 """Error raised elsewhere"""
150 ename=None
151 evalue=None
152 traceback=None
153 engine_info=None
154
155 def __init__(self, ename, evalue, traceback, engine_info=None):
156 self.ename=ename
157 self.evalue=evalue
158 self.traceback=traceback
159 self.engine_info=engine_info or {}
160 self.args=(ename, evalue)
161
162 def __repr__(self):
163 engineid = self.engine_info.get('engineid', ' ')
164 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
165
166 def __str__(self):
167 sig = "%s(%s)"%(self.ename, self.evalue)
168 if self.traceback:
169 return sig + '\n' + self.traceback
170 else:
171 return sig
172
173
174 class TaskRejectError(KernelError):
175 """Exception to raise when a task should be rejected by an engine.
176
177 This exception can be used to allow a task running on an engine to test
178 if the engine (or the user's namespace on the engine) has the needed
179 task dependencies. If not, the task should raise this exception. For
180 the task to be retried on another engine, the task should be created
181 with the `retries` argument > 1.
182
183 The advantage of this approach over our older properties system is that
184 tasks have full access to the user's namespace on the engines and the
185 properties don't have to be managed or tested by the controller.
186 """
187
188
189 class CompositeError(KernelError):
190 """Error for representing possibly multiple errors on engines"""
191 def __init__(self, message, elist):
192 Exception.__init__(self, *(message, elist))
193 # Don't use pack_exception because it will conflict with the .message
194 # attribute that is being deprecated in 2.6 and beyond.
195 self.msg = message
196 self.elist = elist
197 self.args = [ e[0] for e in elist ]
198
199 def _get_engine_str(self, ei):
200 if not ei:
201 return '[Engine Exception]'
202 else:
203 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
204
205 def _get_traceback(self, ev):
206 try:
207 tb = ev._ipython_traceback_text
208 except AttributeError:
209 return 'No traceback available'
210 else:
211 return tb
212
213 def __str__(self):
214 s = str(self.msg)
215 for en, ev, etb, ei in self.elist:
216 engine_str = self._get_engine_str(ei)
217 s = s + '\n' + engine_str + en + ': ' + str(ev)
218 return s
219
220 def __repr__(self):
221 return "CompositeError(%i)"%len(self.elist)
222
223 def print_tracebacks(self, excid=None):
224 if excid is None:
225 for (en,ev,etb,ei) in self.elist:
226 print (self._get_engine_str(ei))
227 print (etb or 'No traceback available')
228 print ()
229 else:
230 try:
231 en,ev,etb,ei = self.elist[excid]
232 except:
233 raise IndexError("an exception with index %i does not exist"%excid)
234 else:
235 print (self._get_engine_str(ei))
236 print (etb or 'No traceback available')
237
238 def raise_exception(self, excid=0):
239 try:
240 en,ev,etb,ei = self.elist[excid]
241 except:
242 raise IndexError("an exception with index %i does not exist"%excid)
243 else:
244 try:
245 raise RemoteError(en, ev, etb, ei)
246 except:
247 et,ev,tb = sys.exc_info()
248
249
250 def collect_exceptions(rdict, method):
251 """check a result dict for errors, and raise CompositeError if any exist.
252 Passthrough otherwise."""
253 elist = []
254 for r in rdict.values():
255 if isinstance(r, RemoteError):
256 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
257 # Sometimes we could have CompositeError in our list. Just take
258 # the errors out of them and put them in our new list. This
259 # has the effect of flattening lists of CompositeErrors into one
260 # CompositeError
261 if en=='CompositeError':
262 for e in ev.elist:
263 elist.append(e)
264 else:
265 elist.append((en, ev, etb, ei))
266 if len(elist)==0:
267 return rdict
268 else:
269 msg = "one or more exceptions from call to method: %s" % (method)
270 # This silliness is needed so the debugger has access to the exception
271 # instance (e in this case)
272 try:
273 raise CompositeError(msg, elist)
274 except CompositeError, e:
275 raise e
276
@@ -1,960 +1,975
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 import os
15 import os
16 import time
16 import time
17 from getpass import getpass
17 from getpass import getpass
18 from pprint import pprint
18 from pprint import pprint
19
19
20 import zmq
20 import zmq
21 from zmq.eventloop import ioloop, zmqstream
21 from zmq.eventloop import ioloop, zmqstream
22
22
23 from IPython.external.decorator import decorator
23 from IPython.external.decorator import decorator
24 from IPython.zmq import tunnel
24 from IPython.zmq import tunnel
25
25
26 import streamsession as ss
26 import streamsession as ss
27 # from remotenamespace import RemoteNamespace
27 # from remotenamespace import RemoteNamespace
28 from view import DirectView, LoadBalancedView
28 from view import DirectView, LoadBalancedView
29 from dependency import Dependency, depend, require
29 from dependency import Dependency, depend, require
30 import error
30
31
31 def _push(ns):
32 def _push(ns):
32 globals().update(ns)
33 globals().update(ns)
33
34
34 def _pull(keys):
35 def _pull(keys):
35 g = globals()
36 g = globals()
36 if isinstance(keys, (list,tuple, set)):
37 if isinstance(keys, (list,tuple, set)):
37 for key in keys:
38 for key in keys:
38 if not g.has_key(key):
39 if not g.has_key(key):
39 raise NameError("name '%s' is not defined"%key)
40 raise NameError("name '%s' is not defined"%key)
40 return map(g.get, keys)
41 return map(g.get, keys)
41 else:
42 else:
42 if not g.has_key(keys):
43 if not g.has_key(keys):
43 raise NameError("name '%s' is not defined"%keys)
44 raise NameError("name '%s' is not defined"%keys)
44 return g.get(keys)
45 return g.get(keys)
45
46
46 def _clear():
47 def _clear():
47 globals().clear()
48 globals().clear()
48
49
49 def execute(code):
50 def execute(code):
50 exec code in globals()
51 exec code in globals()
51
52
52 #--------------------------------------------------------------------------
53 #--------------------------------------------------------------------------
53 # Decorators for Client methods
54 # Decorators for Client methods
54 #--------------------------------------------------------------------------
55 #--------------------------------------------------------------------------
55
56
56 @decorator
57 @decorator
57 def spinfirst(f, self, *args, **kwargs):
58 def spinfirst(f, self, *args, **kwargs):
58 """Call spin() to sync state prior to calling the method."""
59 """Call spin() to sync state prior to calling the method."""
59 self.spin()
60 self.spin()
60 return f(self, *args, **kwargs)
61 return f(self, *args, **kwargs)
61
62
62 @decorator
63 @decorator
63 def defaultblock(f, self, *args, **kwargs):
64 def defaultblock(f, self, *args, **kwargs):
64 """Default to self.block; preserve self.block."""
65 """Default to self.block; preserve self.block."""
65 block = kwargs.get('block',None)
66 block = kwargs.get('block',None)
66 block = self.block if block is None else block
67 block = self.block if block is None else block
67 saveblock = self.block
68 saveblock = self.block
68 self.block = block
69 self.block = block
69 ret = f(self, *args, **kwargs)
70 ret = f(self, *args, **kwargs)
70 self.block = saveblock
71 self.block = saveblock
71 return ret
72 return ret
72
73
73 def remote(client, bound=False, block=None, targets=None):
74 def remote(client, bound=False, block=None, targets=None):
74 """Turn a function into a remote function.
75 """Turn a function into a remote function.
75
76
76 This method can be used for map:
77 This method can be used for map:
77
78
78 >>> @remote(client,block=True)
79 >>> @remote(client,block=True)
79 def func(a)
80 def func(a)
80 """
81 """
81 def remote_function(f):
82 def remote_function(f):
82 return RemoteFunction(client, f, bound, block, targets)
83 return RemoteFunction(client, f, bound, block, targets)
83 return remote_function
84 return remote_function
84
85
85 #--------------------------------------------------------------------------
86 #--------------------------------------------------------------------------
86 # Classes
87 # Classes
87 #--------------------------------------------------------------------------
88 #--------------------------------------------------------------------------
88
89
89 class RemoteFunction(object):
90 class RemoteFunction(object):
90 """Turn an existing function into a remote function.
91 """Turn an existing function into a remote function.
91
92
92 Parameters
93 Parameters
93 ----------
94 ----------
94
95
95 client : Client instance
96 client : Client instance
96 The client to be used to connect to engines
97 The client to be used to connect to engines
97 f : callable
98 f : callable
98 The function to be wrapped into a remote function
99 The function to be wrapped into a remote function
99 bound : bool [default: False]
100 bound : bool [default: False]
100 Whether the affect the remote namespace when called
101 Whether the affect the remote namespace when called
101 block : bool [default: None]
102 block : bool [default: None]
102 Whether to wait for results or not. The default behavior is
103 Whether to wait for results or not. The default behavior is
103 to use the current `block` attribute of `client`
104 to use the current `block` attribute of `client`
104 targets : valid target list [default: all]
105 targets : valid target list [default: all]
105 The targets on which to execute.
106 The targets on which to execute.
106 """
107 """
107
108
108 client = None # the remote connection
109 client = None # the remote connection
109 func = None # the wrapped function
110 func = None # the wrapped function
110 block = None # whether to block
111 block = None # whether to block
111 bound = None # whether to affect the namespace
112 bound = None # whether to affect the namespace
112 targets = None # where to execute
113 targets = None # where to execute
113
114
114 def __init__(self, client, f, bound=False, block=None, targets=None):
115 def __init__(self, client, f, bound=False, block=None, targets=None):
115 self.client = client
116 self.client = client
116 self.func = f
117 self.func = f
117 self.block=block
118 self.block=block
118 self.bound=bound
119 self.bound=bound
119 self.targets=targets
120 self.targets=targets
120
121
121 def __call__(self, *args, **kwargs):
122 def __call__(self, *args, **kwargs):
122 return self.client.apply(self.func, args=args, kwargs=kwargs,
123 return self.client.apply(self.func, args=args, kwargs=kwargs,
123 block=self.block, targets=self.targets, bound=self.bound)
124 block=self.block, targets=self.targets, bound=self.bound)
124
125
125
126
126 class AbortedTask(object):
127 class AbortedTask(object):
127 """A basic wrapper object describing an aborted task."""
128 """A basic wrapper object describing an aborted task."""
128 def __init__(self, msg_id):
129 def __init__(self, msg_id):
129 self.msg_id = msg_id
130 self.msg_id = msg_id
130
131
131 class ControllerError(Exception):
132 class ResultDict(dict):
132 """Exception Class for errors in the controller (not the Engine)."""
133 """A subclass of dict that raises errors if it has them."""
133 def __init__(self, etype, evalue, tb):
134 def __getitem__(self, key):
134 self.etype = etype
135 res = dict.__getitem__(self, key)
135 self.evalue = evalue
136 if isinstance(res, error.KernelError):
136 self.traceback=tb
137 raise res
138 return res
137
139
138 class Client(object):
140 class Client(object):
139 """A semi-synchronous client to the IPython ZMQ controller
141 """A semi-synchronous client to the IPython ZMQ controller
140
142
141 Parameters
143 Parameters
142 ----------
144 ----------
143
145
144 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
146 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
145 The address of the controller's registration socket.
147 The address of the controller's registration socket.
146 [Default: 'tcp://127.0.0.1:10101']
148 [Default: 'tcp://127.0.0.1:10101']
147 context : zmq.Context
149 context : zmq.Context
148 Pass an existing zmq.Context instance, otherwise the client will create its own
150 Pass an existing zmq.Context instance, otherwise the client will create its own
149 username : bytes
151 username : bytes
150 set username to be passed to the Session object
152 set username to be passed to the Session object
151 debug : bool
153 debug : bool
152 flag for lots of message printing for debug purposes
154 flag for lots of message printing for debug purposes
153
155
154 #-------------- ssh related args ----------------
156 #-------------- ssh related args ----------------
155 # These are args for configuring the ssh tunnel to be used
157 # These are args for configuring the ssh tunnel to be used
156 # credentials are used to forward connections over ssh to the Controller
158 # credentials are used to forward connections over ssh to the Controller
157 # Note that the ip given in `addr` needs to be relative to sshserver
159 # Note that the ip given in `addr` needs to be relative to sshserver
158 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
160 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
159 # and set sshserver as the same machine the Controller is on. However,
161 # and set sshserver as the same machine the Controller is on. However,
160 # the only requirement is that sshserver is able to see the Controller
162 # the only requirement is that sshserver is able to see the Controller
161 # (i.e. is within the same trusted network).
163 # (i.e. is within the same trusted network).
162
164
163 sshserver : str
165 sshserver : str
164 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
166 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
165 If keyfile or password is specified, and this is not, it will default to
167 If keyfile or password is specified, and this is not, it will default to
166 the ip given in addr.
168 the ip given in addr.
167 sshkey : str; path to public ssh key file
169 sshkey : str; path to public ssh key file
168 This specifies a key to be used in ssh login, default None.
170 This specifies a key to be used in ssh login, default None.
169 Regular default ssh keys will be used without specifying this argument.
171 Regular default ssh keys will be used without specifying this argument.
170 password : str;
172 password : str;
171 Your ssh password to sshserver. Note that if this is left None,
173 Your ssh password to sshserver. Note that if this is left None,
172 you will be prompted for it if passwordless key based login is unavailable.
174 you will be prompted for it if passwordless key based login is unavailable.
173
175
174 #------- exec authentication args -------
176 #------- exec authentication args -------
175 # If even localhost is untrusted, you can have some protection against
177 # If even localhost is untrusted, you can have some protection against
176 # unauthorized execution by using a key. Messages are still sent
178 # unauthorized execution by using a key. Messages are still sent
177 # as cleartext, so if someone can snoop your loopback traffic this will
179 # as cleartext, so if someone can snoop your loopback traffic this will
178 # not help anything.
180 # not help anything.
179
181
180 exec_key : str
182 exec_key : str
181 an authentication key or file containing a key
183 an authentication key or file containing a key
182 default: None
184 default: None
183
185
184
186
185 Attributes
187 Attributes
186 ----------
188 ----------
187 ids : set of int engine IDs
189 ids : set of int engine IDs
188 requesting the ids attribute always synchronizes
190 requesting the ids attribute always synchronizes
189 the registration state. To request ids without synchronization,
191 the registration state. To request ids without synchronization,
190 use semi-private _ids attributes.
192 use semi-private _ids attributes.
191
193
192 history : list of msg_ids
194 history : list of msg_ids
193 a list of msg_ids, keeping track of all the execution
195 a list of msg_ids, keeping track of all the execution
194 messages you have submitted in order.
196 messages you have submitted in order.
195
197
196 outstanding : set of msg_ids
198 outstanding : set of msg_ids
197 a set of msg_ids that have been submitted, but whose
199 a set of msg_ids that have been submitted, but whose
198 results have not yet been received.
200 results have not yet been received.
199
201
200 results : dict
202 results : dict
201 a dict of all our results, keyed by msg_id
203 a dict of all our results, keyed by msg_id
202
204
203 block : bool
205 block : bool
204 determines default behavior when block not specified
206 determines default behavior when block not specified
205 in execution methods
207 in execution methods
206
208
207 Methods
209 Methods
208 -------
210 -------
209 spin : flushes incoming results and registration state changes
211 spin : flushes incoming results and registration state changes
210 control methods spin, and requesting `ids` also ensures up to date
212 control methods spin, and requesting `ids` also ensures up to date
211
213
212 barrier : wait on one or more msg_ids
214 barrier : wait on one or more msg_ids
213
215
214 execution methods: apply/apply_bound/apply_to/apply_bound
216 execution methods: apply/apply_bound/apply_to/apply_bound
215 legacy: execute, run
217 legacy: execute, run
216
218
217 query methods: queue_status, get_result, purge
219 query methods: queue_status, get_result, purge
218
220
219 control methods: abort, kill
221 control methods: abort, kill
220
222
221 """
223 """
222
224
223
225
224 _connected=False
226 _connected=False
225 _ssh=False
227 _ssh=False
226 _engines=None
228 _engines=None
227 _addr='tcp://127.0.0.1:10101'
229 _addr='tcp://127.0.0.1:10101'
228 _registration_socket=None
230 _registration_socket=None
229 _query_socket=None
231 _query_socket=None
230 _control_socket=None
232 _control_socket=None
231 _notification_socket=None
233 _notification_socket=None
232 _mux_socket=None
234 _mux_socket=None
233 _task_socket=None
235 _task_socket=None
234 block = False
236 block = False
235 outstanding=None
237 outstanding=None
236 results = None
238 results = None
237 history = None
239 history = None
238 debug = False
240 debug = False
239
241
240 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
242 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
241 sshserver=None, sshkey=None, password=None, paramiko=None,
243 sshserver=None, sshkey=None, password=None, paramiko=None,
242 exec_key=None,):
244 exec_key=None,):
243 if context is None:
245 if context is None:
244 context = zmq.Context()
246 context = zmq.Context()
245 self.context = context
247 self.context = context
246 self._addr = addr
248 self._addr = addr
247 self._ssh = bool(sshserver or sshkey or password)
249 self._ssh = bool(sshserver or sshkey or password)
248 if self._ssh and sshserver is None:
250 if self._ssh and sshserver is None:
249 # default to the same
251 # default to the same
250 sshserver = addr.split('://')[1].split(':')[0]
252 sshserver = addr.split('://')[1].split(':')[0]
251 if self._ssh and password is None:
253 if self._ssh and password is None:
252 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
254 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
253 password=False
255 password=False
254 else:
256 else:
255 password = getpass("SSH Password for %s: "%sshserver)
257 password = getpass("SSH Password for %s: "%sshserver)
256 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
258 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
257
259
258 if exec_key is not None and os.path.isfile(exec_key):
260 if exec_key is not None and os.path.isfile(exec_key):
259 arg = 'keyfile'
261 arg = 'keyfile'
260 else:
262 else:
261 arg = 'key'
263 arg = 'key'
262 key_arg = {arg:exec_key}
264 key_arg = {arg:exec_key}
263 if username is None:
265 if username is None:
264 self.session = ss.StreamSession(**key_arg)
266 self.session = ss.StreamSession(**key_arg)
265 else:
267 else:
266 self.session = ss.StreamSession(username, **key_arg)
268 self.session = ss.StreamSession(username, **key_arg)
267 self._registration_socket = self.context.socket(zmq.XREQ)
269 self._registration_socket = self.context.socket(zmq.XREQ)
268 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
270 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
269 if self._ssh:
271 if self._ssh:
270 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
272 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
271 else:
273 else:
272 self._registration_socket.connect(addr)
274 self._registration_socket.connect(addr)
273 self._engines = {}
275 self._engines = {}
274 self._ids = set()
276 self._ids = set()
275 self.outstanding=set()
277 self.outstanding=set()
276 self.results = {}
278 self.results = {}
277 self.history = []
279 self.history = []
278 self.debug = debug
280 self.debug = debug
279 self.session.debug = debug
281 self.session.debug = debug
280
282
281 self._notification_handlers = {'registration_notification' : self._register_engine,
283 self._notification_handlers = {'registration_notification' : self._register_engine,
282 'unregistration_notification' : self._unregister_engine,
284 'unregistration_notification' : self._unregister_engine,
283 }
285 }
284 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
286 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
285 'apply_reply' : self._handle_apply_reply}
287 'apply_reply' : self._handle_apply_reply}
286 self._connect(sshserver, ssh_kwargs)
288 self._connect(sshserver, ssh_kwargs)
287
289
288
290
289 @property
291 @property
290 def ids(self):
292 def ids(self):
291 """Always up to date ids property."""
293 """Always up to date ids property."""
292 self._flush_notifications()
294 self._flush_notifications()
293 return self._ids
295 return self._ids
294
296
295 def _update_engines(self, engines):
297 def _update_engines(self, engines):
296 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
298 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
297 for k,v in engines.iteritems():
299 for k,v in engines.iteritems():
298 eid = int(k)
300 eid = int(k)
299 self._engines[eid] = bytes(v) # force not unicode
301 self._engines[eid] = bytes(v) # force not unicode
300 self._ids.add(eid)
302 self._ids.add(eid)
301
303
302 def _build_targets(self, targets):
304 def _build_targets(self, targets):
303 """Turn valid target IDs or 'all' into two lists:
305 """Turn valid target IDs or 'all' into two lists:
304 (int_ids, uuids).
306 (int_ids, uuids).
305 """
307 """
306 if targets is None:
308 if targets is None:
307 targets = self._ids
309 targets = self._ids
308 elif isinstance(targets, str):
310 elif isinstance(targets, str):
309 if targets.lower() == 'all':
311 if targets.lower() == 'all':
310 targets = self._ids
312 targets = self._ids
311 else:
313 else:
312 raise TypeError("%r not valid str target, must be 'all'"%(targets))
314 raise TypeError("%r not valid str target, must be 'all'"%(targets))
313 elif isinstance(targets, int):
315 elif isinstance(targets, int):
314 targets = [targets]
316 targets = [targets]
315 return [self._engines[t] for t in targets], list(targets)
317 return [self._engines[t] for t in targets], list(targets)
316
318
317 def _connect(self, sshserver, ssh_kwargs):
319 def _connect(self, sshserver, ssh_kwargs):
318 """setup all our socket connections to the controller. This is called from
320 """setup all our socket connections to the controller. This is called from
319 __init__."""
321 __init__."""
320 if self._connected:
322 if self._connected:
321 return
323 return
322 self._connected=True
324 self._connected=True
323
325
324 def connect_socket(s, addr):
326 def connect_socket(s, addr):
325 if self._ssh:
327 if self._ssh:
326 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
328 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
327 else:
329 else:
328 return s.connect(addr)
330 return s.connect(addr)
329
331
330 self.session.send(self._registration_socket, 'connection_request')
332 self.session.send(self._registration_socket, 'connection_request')
331 idents,msg = self.session.recv(self._registration_socket,mode=0)
333 idents,msg = self.session.recv(self._registration_socket,mode=0)
332 if self.debug:
334 if self.debug:
333 pprint(msg)
335 pprint(msg)
334 msg = ss.Message(msg)
336 msg = ss.Message(msg)
335 content = msg.content
337 content = msg.content
336 if content.status == 'ok':
338 if content.status == 'ok':
337 if content.queue:
339 if content.queue:
338 self._mux_socket = self.context.socket(zmq.PAIR)
340 self._mux_socket = self.context.socket(zmq.PAIR)
339 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
341 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
340 connect_socket(self._mux_socket, content.queue)
342 connect_socket(self._mux_socket, content.queue)
341 if content.task:
343 if content.task:
342 self._task_socket = self.context.socket(zmq.PAIR)
344 self._task_socket = self.context.socket(zmq.PAIR)
343 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
345 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
344 connect_socket(self._task_socket, content.task)
346 connect_socket(self._task_socket, content.task)
345 if content.notification:
347 if content.notification:
346 self._notification_socket = self.context.socket(zmq.SUB)
348 self._notification_socket = self.context.socket(zmq.SUB)
347 connect_socket(self._notification_socket, content.notification)
349 connect_socket(self._notification_socket, content.notification)
348 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
350 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
349 if content.query:
351 if content.query:
350 self._query_socket = self.context.socket(zmq.PAIR)
352 self._query_socket = self.context.socket(zmq.PAIR)
351 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
353 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
352 connect_socket(self._query_socket, content.query)
354 connect_socket(self._query_socket, content.query)
353 if content.control:
355 if content.control:
354 self._control_socket = self.context.socket(zmq.PAIR)
356 self._control_socket = self.context.socket(zmq.PAIR)
355 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
357 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
356 connect_socket(self._control_socket, content.control)
358 connect_socket(self._control_socket, content.control)
357 self._update_engines(dict(content.engines))
359 self._update_engines(dict(content.engines))
358
360
359 else:
361 else:
360 self._connected = False
362 self._connected = False
361 raise Exception("Failed to connect!")
363 raise Exception("Failed to connect!")
362
364
363 #--------------------------------------------------------------------------
365 #--------------------------------------------------------------------------
364 # handlers and callbacks for incoming messages
366 # handlers and callbacks for incoming messages
365 #--------------------------------------------------------------------------
367 #--------------------------------------------------------------------------
366
368
367 def _register_engine(self, msg):
369 def _register_engine(self, msg):
368 """Register a new engine, and update our connection info."""
370 """Register a new engine, and update our connection info."""
369 content = msg['content']
371 content = msg['content']
370 eid = content['id']
372 eid = content['id']
371 d = {eid : content['queue']}
373 d = {eid : content['queue']}
372 self._update_engines(d)
374 self._update_engines(d)
373 self._ids.add(int(eid))
375 self._ids.add(int(eid))
374
376
375 def _unregister_engine(self, msg):
377 def _unregister_engine(self, msg):
376 """Unregister an engine that has died."""
378 """Unregister an engine that has died."""
377 content = msg['content']
379 content = msg['content']
378 eid = int(content['id'])
380 eid = int(content['id'])
379 if eid in self._ids:
381 if eid in self._ids:
380 self._ids.remove(eid)
382 self._ids.remove(eid)
381 self._engines.pop(eid)
383 self._engines.pop(eid)
382
384
383 def _handle_execute_reply(self, msg):
385 def _handle_execute_reply(self, msg):
384 """Save the reply to an execute_request into our results."""
386 """Save the reply to an execute_request into our results."""
385 parent = msg['parent_header']
387 parent = msg['parent_header']
386 msg_id = parent['msg_id']
388 msg_id = parent['msg_id']
387 if msg_id not in self.outstanding:
389 if msg_id not in self.outstanding:
388 print("got unknown result: %s"%msg_id)
390 print("got unknown result: %s"%msg_id)
389 else:
391 else:
390 self.outstanding.remove(msg_id)
392 self.outstanding.remove(msg_id)
391 self.results[msg_id] = ss.unwrap_exception(msg['content'])
393 self.results[msg_id] = ss.unwrap_exception(msg['content'])
392
394
393 def _handle_apply_reply(self, msg):
395 def _handle_apply_reply(self, msg):
394 """Save the reply to an apply_request into our results."""
396 """Save the reply to an apply_request into our results."""
395 parent = msg['parent_header']
397 parent = msg['parent_header']
396 msg_id = parent['msg_id']
398 msg_id = parent['msg_id']
397 if msg_id not in self.outstanding:
399 if msg_id not in self.outstanding:
398 print ("got unknown result: %s"%msg_id)
400 print ("got unknown result: %s"%msg_id)
399 else:
401 else:
400 self.outstanding.remove(msg_id)
402 self.outstanding.remove(msg_id)
401 content = msg['content']
403 content = msg['content']
402 if content['status'] == 'ok':
404 if content['status'] == 'ok':
403 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
405 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
404 elif content['status'] == 'aborted':
406 elif content['status'] == 'aborted':
405 self.results[msg_id] = AbortedTask(msg_id)
407 self.results[msg_id] = error.AbortedTask(msg_id)
406 elif content['status'] == 'resubmitted':
408 elif content['status'] == 'resubmitted':
407 # TODO: handle resubmission
409 # TODO: handle resubmission
408 pass
410 pass
409 else:
411 else:
410 self.results[msg_id] = ss.unwrap_exception(content)
412 e = ss.unwrap_exception(content)
413 e_uuid = e.engine_info['engineid']
414 for k,v in self._engines.iteritems():
415 if v == e_uuid:
416 e.engine_info['engineid'] = k
417 break
418 self.results[msg_id] = e
411
419
412 def _flush_notifications(self):
420 def _flush_notifications(self):
413 """Flush notifications of engine registrations waiting
421 """Flush notifications of engine registrations waiting
414 in ZMQ queue."""
422 in ZMQ queue."""
415 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
423 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
416 while msg is not None:
424 while msg is not None:
417 if self.debug:
425 if self.debug:
418 pprint(msg)
426 pprint(msg)
419 msg = msg[-1]
427 msg = msg[-1]
420 msg_type = msg['msg_type']
428 msg_type = msg['msg_type']
421 handler = self._notification_handlers.get(msg_type, None)
429 handler = self._notification_handlers.get(msg_type, None)
422 if handler is None:
430 if handler is None:
423 raise Exception("Unhandled message type: %s"%msg.msg_type)
431 raise Exception("Unhandled message type: %s"%msg.msg_type)
424 else:
432 else:
425 handler(msg)
433 handler(msg)
426 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
434 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
427
435
428 def _flush_results(self, sock):
436 def _flush_results(self, sock):
429 """Flush task or queue results waiting in ZMQ queue."""
437 """Flush task or queue results waiting in ZMQ queue."""
430 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
438 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
431 while msg is not None:
439 while msg is not None:
432 if self.debug:
440 if self.debug:
433 pprint(msg)
441 pprint(msg)
434 msg = msg[-1]
442 msg = msg[-1]
435 msg_type = msg['msg_type']
443 msg_type = msg['msg_type']
436 handler = self._queue_handlers.get(msg_type, None)
444 handler = self._queue_handlers.get(msg_type, None)
437 if handler is None:
445 if handler is None:
438 raise Exception("Unhandled message type: %s"%msg.msg_type)
446 raise Exception("Unhandled message type: %s"%msg.msg_type)
439 else:
447 else:
440 handler(msg)
448 handler(msg)
441 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
449 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
442
450
443 def _flush_control(self, sock):
451 def _flush_control(self, sock):
444 """Flush replies from the control channel waiting
452 """Flush replies from the control channel waiting
445 in the ZMQ queue.
453 in the ZMQ queue.
446
454
447 Currently: ignore them."""
455 Currently: ignore them."""
448 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
456 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
449 while msg is not None:
457 while msg is not None:
450 if self.debug:
458 if self.debug:
451 pprint(msg)
459 pprint(msg)
452 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
460 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
453
461
454 #--------------------------------------------------------------------------
462 #--------------------------------------------------------------------------
455 # getitem
463 # getitem
456 #--------------------------------------------------------------------------
464 #--------------------------------------------------------------------------
457
465
458 def __getitem__(self, key):
466 def __getitem__(self, key):
459 """Dict access returns DirectView multiplexer objects or,
467 """Dict access returns DirectView multiplexer objects or,
460 if key is None, a LoadBalancedView."""
468 if key is None, a LoadBalancedView."""
461 if key is None:
469 if key is None:
462 return LoadBalancedView(self)
470 return LoadBalancedView(self)
463 if isinstance(key, int):
471 if isinstance(key, int):
464 if key not in self.ids:
472 if key not in self.ids:
465 raise IndexError("No such engine: %i"%key)
473 raise IndexError("No such engine: %i"%key)
466 return DirectView(self, key)
474 return DirectView(self, key)
467
475
468 if isinstance(key, slice):
476 if isinstance(key, slice):
469 indices = range(len(self.ids))[key]
477 indices = range(len(self.ids))[key]
470 ids = sorted(self._ids)
478 ids = sorted(self._ids)
471 key = [ ids[i] for i in indices ]
479 key = [ ids[i] for i in indices ]
472 # newkeys = sorted(self._ids)[thekeys[k]]
480 # newkeys = sorted(self._ids)[thekeys[k]]
473
481
474 if isinstance(key, (tuple, list, xrange)):
482 if isinstance(key, (tuple, list, xrange)):
475 _,targets = self._build_targets(list(key))
483 _,targets = self._build_targets(list(key))
476 return DirectView(self, targets)
484 return DirectView(self, targets)
477 else:
485 else:
478 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
486 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
479
487
480 #--------------------------------------------------------------------------
488 #--------------------------------------------------------------------------
481 # Begin public methods
489 # Begin public methods
482 #--------------------------------------------------------------------------
490 #--------------------------------------------------------------------------
483
491
484 def spin(self):
492 def spin(self):
485 """Flush any registration notifications and execution results
493 """Flush any registration notifications and execution results
486 waiting in the ZMQ queue.
494 waiting in the ZMQ queue.
487 """
495 """
488 if self._notification_socket:
496 if self._notification_socket:
489 self._flush_notifications()
497 self._flush_notifications()
490 if self._mux_socket:
498 if self._mux_socket:
491 self._flush_results(self._mux_socket)
499 self._flush_results(self._mux_socket)
492 if self._task_socket:
500 if self._task_socket:
493 self._flush_results(self._task_socket)
501 self._flush_results(self._task_socket)
494 if self._control_socket:
502 if self._control_socket:
495 self._flush_control(self._control_socket)
503 self._flush_control(self._control_socket)
496
504
497 def barrier(self, msg_ids=None, timeout=-1):
505 def barrier(self, msg_ids=None, timeout=-1):
498 """waits on one or more `msg_ids`, for up to `timeout` seconds.
506 """waits on one or more `msg_ids`, for up to `timeout` seconds.
499
507
500 Parameters
508 Parameters
501 ----------
509 ----------
502 msg_ids : int, str, or list of ints and/or strs
510 msg_ids : int, str, or list of ints and/or strs
503 ints are indices to self.history
511 ints are indices to self.history
504 strs are msg_ids
512 strs are msg_ids
505 default: wait on all outstanding messages
513 default: wait on all outstanding messages
506 timeout : float
514 timeout : float
507 a time in seconds, after which to give up.
515 a time in seconds, after which to give up.
508 default is -1, which means no timeout
516 default is -1, which means no timeout
509
517
510 Returns
518 Returns
511 -------
519 -------
512 True : when all msg_ids are done
520 True : when all msg_ids are done
513 False : timeout reached, some msg_ids still outstanding
521 False : timeout reached, some msg_ids still outstanding
514 """
522 """
515 tic = time.time()
523 tic = time.time()
516 if msg_ids is None:
524 if msg_ids is None:
517 theids = self.outstanding
525 theids = self.outstanding
518 else:
526 else:
519 if isinstance(msg_ids, (int, str)):
527 if isinstance(msg_ids, (int, str)):
520 msg_ids = [msg_ids]
528 msg_ids = [msg_ids]
521 theids = set()
529 theids = set()
522 for msg_id in msg_ids:
530 for msg_id in msg_ids:
523 if isinstance(msg_id, int):
531 if isinstance(msg_id, int):
524 msg_id = self.history[msg_id]
532 msg_id = self.history[msg_id]
525 theids.add(msg_id)
533 theids.add(msg_id)
526 self.spin()
534 self.spin()
527 while theids.intersection(self.outstanding):
535 while theids.intersection(self.outstanding):
528 if timeout >= 0 and ( time.time()-tic ) > timeout:
536 if timeout >= 0 and ( time.time()-tic ) > timeout:
529 break
537 break
530 time.sleep(1e-3)
538 time.sleep(1e-3)
531 self.spin()
539 self.spin()
532 return len(theids.intersection(self.outstanding)) == 0
540 return len(theids.intersection(self.outstanding)) == 0
533
541
534 #--------------------------------------------------------------------------
542 #--------------------------------------------------------------------------
535 # Control methods
543 # Control methods
536 #--------------------------------------------------------------------------
544 #--------------------------------------------------------------------------
537
545
538 @spinfirst
546 @spinfirst
539 @defaultblock
547 @defaultblock
540 def clear(self, targets=None, block=None):
548 def clear(self, targets=None, block=None):
541 """Clear the namespace in target(s)."""
549 """Clear the namespace in target(s)."""
542 targets = self._build_targets(targets)[0]
550 targets = self._build_targets(targets)[0]
543 for t in targets:
551 for t in targets:
544 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
552 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
545 error = False
553 error = False
546 if self.block:
554 if self.block:
547 for i in range(len(targets)):
555 for i in range(len(targets)):
548 idents,msg = self.session.recv(self._control_socket,0)
556 idents,msg = self.session.recv(self._control_socket,0)
549 if self.debug:
557 if self.debug:
550 pprint(msg)
558 pprint(msg)
551 if msg['content']['status'] != 'ok':
559 if msg['content']['status'] != 'ok':
552 error = ss.unwrap_exception(msg['content'])
560 error = ss.unwrap_exception(msg['content'])
553 if error:
561 if error:
554 return error
562 return error
555
563
556
564
557 @spinfirst
565 @spinfirst
558 @defaultblock
566 @defaultblock
559 def abort(self, msg_ids = None, targets=None, block=None):
567 def abort(self, msg_ids = None, targets=None, block=None):
560 """Abort the execution queues of target(s)."""
568 """Abort the execution queues of target(s)."""
561 targets = self._build_targets(targets)[0]
569 targets = self._build_targets(targets)[0]
562 if isinstance(msg_ids, basestring):
570 if isinstance(msg_ids, basestring):
563 msg_ids = [msg_ids]
571 msg_ids = [msg_ids]
564 content = dict(msg_ids=msg_ids)
572 content = dict(msg_ids=msg_ids)
565 for t in targets:
573 for t in targets:
566 self.session.send(self._control_socket, 'abort_request',
574 self.session.send(self._control_socket, 'abort_request',
567 content=content, ident=t)
575 content=content, ident=t)
568 error = False
576 error = False
569 if self.block:
577 if self.block:
570 for i in range(len(targets)):
578 for i in range(len(targets)):
571 idents,msg = self.session.recv(self._control_socket,0)
579 idents,msg = self.session.recv(self._control_socket,0)
572 if self.debug:
580 if self.debug:
573 pprint(msg)
581 pprint(msg)
574 if msg['content']['status'] != 'ok':
582 if msg['content']['status'] != 'ok':
575 error = ss.unwrap_exception(msg['content'])
583 error = ss.unwrap_exception(msg['content'])
576 if error:
584 if error:
577 return error
585 return error
578
586
579 @spinfirst
587 @spinfirst
580 @defaultblock
588 @defaultblock
581 def shutdown(self, targets=None, restart=False, controller=False, block=None):
589 def shutdown(self, targets=None, restart=False, controller=False, block=None):
582 """Terminates one or more engine processes, optionally including the controller."""
590 """Terminates one or more engine processes, optionally including the controller."""
583 if controller:
591 if controller:
584 targets = 'all'
592 targets = 'all'
585 targets = self._build_targets(targets)[0]
593 targets = self._build_targets(targets)[0]
586 for t in targets:
594 for t in targets:
587 self.session.send(self._control_socket, 'shutdown_request',
595 self.session.send(self._control_socket, 'shutdown_request',
588 content={'restart':restart},ident=t)
596 content={'restart':restart},ident=t)
589 error = False
597 error = False
590 if block or controller:
598 if block or controller:
591 for i in range(len(targets)):
599 for i in range(len(targets)):
592 idents,msg = self.session.recv(self._control_socket,0)
600 idents,msg = self.session.recv(self._control_socket,0)
593 if self.debug:
601 if self.debug:
594 pprint(msg)
602 pprint(msg)
595 if msg['content']['status'] != 'ok':
603 if msg['content']['status'] != 'ok':
596 error = ss.unwrap_exception(msg['content'])
604 error = ss.unwrap_exception(msg['content'])
597
605
598 if controller:
606 if controller:
599 time.sleep(0.25)
607 time.sleep(0.25)
600 self.session.send(self._query_socket, 'shutdown_request')
608 self.session.send(self._query_socket, 'shutdown_request')
601 idents,msg = self.session.recv(self._query_socket, 0)
609 idents,msg = self.session.recv(self._query_socket, 0)
602 if self.debug:
610 if self.debug:
603 pprint(msg)
611 pprint(msg)
604 if msg['content']['status'] != 'ok':
612 if msg['content']['status'] != 'ok':
605 error = ss.unwrap_exception(msg['content'])
613 error = ss.unwrap_exception(msg['content'])
606
614
607 if error:
615 if error:
608 return error
616 return error
609
617
610 #--------------------------------------------------------------------------
618 #--------------------------------------------------------------------------
611 # Execution methods
619 # Execution methods
612 #--------------------------------------------------------------------------
620 #--------------------------------------------------------------------------
613
621
614 @defaultblock
622 @defaultblock
615 def execute(self, code, targets='all', block=None):
623 def execute(self, code, targets='all', block=None):
616 """Executes `code` on `targets` in blocking or nonblocking manner.
624 """Executes `code` on `targets` in blocking or nonblocking manner.
617
625
618 Parameters
626 Parameters
619 ----------
627 ----------
620 code : str
628 code : str
621 the code string to be executed
629 the code string to be executed
622 targets : int/str/list of ints/strs
630 targets : int/str/list of ints/strs
623 the engines on which to execute
631 the engines on which to execute
624 default : all
632 default : all
625 block : bool
633 block : bool
626 whether or not to wait until done to return
634 whether or not to wait until done to return
627 default: self.block
635 default: self.block
628 """
636 """
629 # block = self.block if block is None else block
637 # block = self.block if block is None else block
630 # saveblock = self.block
638 # saveblock = self.block
631 # self.block = block
639 # self.block = block
632 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
640 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
633 # self.block = saveblock
641 # self.block = saveblock
634 return result
642 return result
635
643
636 def run(self, code, block=None):
644 def run(self, code, block=None):
637 """Runs `code` on an engine.
645 """Runs `code` on an engine.
638
646
639 Calls to this are load-balanced.
647 Calls to this are load-balanced.
640
648
641 Parameters
649 Parameters
642 ----------
650 ----------
643 code : str
651 code : str
644 the code string to be executed
652 the code string to be executed
645 block : bool
653 block : bool
646 whether or not to wait until done
654 whether or not to wait until done
647
655
648 """
656 """
649 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
657 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
650 return result
658 return result
651
659
660 def _maybe_raise(self, result):
661 """wrapper for maybe raising an exception if apply failed."""
662 if isinstance(result, error.RemoteError):
663 raise result
664
665 return result
666
652 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
667 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
653 after=None, follow=None):
668 after=None, follow=None):
654 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
669 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
655
670
656 This is the central execution command for the client.
671 This is the central execution command for the client.
657
672
658 Parameters
673 Parameters
659 ----------
674 ----------
660
675
661 f : function
676 f : function
662 The fuction to be called remotely
677 The fuction to be called remotely
663 args : tuple/list
678 args : tuple/list
664 The positional arguments passed to `f`
679 The positional arguments passed to `f`
665 kwargs : dict
680 kwargs : dict
666 The keyword arguments passed to `f`
681 The keyword arguments passed to `f`
667 bound : bool (default: True)
682 bound : bool (default: True)
668 Whether to execute in the Engine(s) namespace, or in a clean
683 Whether to execute in the Engine(s) namespace, or in a clean
669 namespace not affecting the engine.
684 namespace not affecting the engine.
670 block : bool (default: self.block)
685 block : bool (default: self.block)
671 Whether to wait for the result, or return immediately.
686 Whether to wait for the result, or return immediately.
672 False:
687 False:
673 returns msg_id(s)
688 returns msg_id(s)
674 if multiple targets:
689 if multiple targets:
675 list of ids
690 list of ids
676 True:
691 True:
677 returns actual result(s) of f(*args, **kwargs)
692 returns actual result(s) of f(*args, **kwargs)
678 if multiple targets:
693 if multiple targets:
679 dict of results, by engine ID
694 dict of results, by engine ID
680 targets : int,list of ints, 'all', None
695 targets : int,list of ints, 'all', None
681 Specify the destination of the job.
696 Specify the destination of the job.
682 if None:
697 if None:
683 Submit via Task queue for load-balancing.
698 Submit via Task queue for load-balancing.
684 if 'all':
699 if 'all':
685 Run on all active engines
700 Run on all active engines
686 if list:
701 if list:
687 Run on each specified engine
702 Run on each specified engine
688 if int:
703 if int:
689 Run on single engine
704 Run on single engine
690
705
691 after : Dependency or collection of msg_ids
706 after : Dependency or collection of msg_ids
692 Only for load-balanced execution (targets=None)
707 Only for load-balanced execution (targets=None)
693 Specify a list of msg_ids as a time-based dependency.
708 Specify a list of msg_ids as a time-based dependency.
694 This job will only be run *after* the dependencies
709 This job will only be run *after* the dependencies
695 have been met.
710 have been met.
696
711
697 follow : Dependency or collection of msg_ids
712 follow : Dependency or collection of msg_ids
698 Only for load-balanced execution (targets=None)
713 Only for load-balanced execution (targets=None)
699 Specify a list of msg_ids as a location-based dependency.
714 Specify a list of msg_ids as a location-based dependency.
700 This job will only be run on an engine where this dependency
715 This job will only be run on an engine where this dependency
701 is met.
716 is met.
702
717
703 Returns
718 Returns
704 -------
719 -------
705 if block is False:
720 if block is False:
706 if single target:
721 if single target:
707 return msg_id
722 return msg_id
708 else:
723 else:
709 return list of msg_ids
724 return list of msg_ids
710 ? (should this be dict like block=True) ?
725 ? (should this be dict like block=True) ?
711 else:
726 else:
712 if single target:
727 if single target:
713 return result of f(*args, **kwargs)
728 return result of f(*args, **kwargs)
714 else:
729 else:
715 return dict of results, keyed by engine
730 return dict of results, keyed by engine
716 """
731 """
717
732
718 # defaults:
733 # defaults:
719 block = block if block is not None else self.block
734 block = block if block is not None else self.block
720 args = args if args is not None else []
735 args = args if args is not None else []
721 kwargs = kwargs if kwargs is not None else {}
736 kwargs = kwargs if kwargs is not None else {}
722
737
723 # enforce types of f,args,kwrags
738 # enforce types of f,args,kwrags
724 if not callable(f):
739 if not callable(f):
725 raise TypeError("f must be callable, not %s"%type(f))
740 raise TypeError("f must be callable, not %s"%type(f))
726 if not isinstance(args, (tuple, list)):
741 if not isinstance(args, (tuple, list)):
727 raise TypeError("args must be tuple or list, not %s"%type(args))
742 raise TypeError("args must be tuple or list, not %s"%type(args))
728 if not isinstance(kwargs, dict):
743 if not isinstance(kwargs, dict):
729 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
744 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
730
745
731 options = dict(bound=bound, block=block, after=after, follow=follow)
746 options = dict(bound=bound, block=block, after=after, follow=follow)
732
747
733 if targets is None:
748 if targets is None:
734 return self._apply_balanced(f, args, kwargs, **options)
749 return self._apply_balanced(f, args, kwargs, **options)
735 else:
750 else:
736 return self._apply_direct(f, args, kwargs, targets=targets, **options)
751 return self._apply_direct(f, args, kwargs, targets=targets, **options)
737
752
738 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
753 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
739 after=None, follow=None):
754 after=None, follow=None):
740 """The underlying method for applying functions in a load balanced
755 """The underlying method for applying functions in a load balanced
741 manner, via the task queue."""
756 manner, via the task queue."""
742 if isinstance(after, Dependency):
757 if isinstance(after, Dependency):
743 after = after.as_dict()
758 after = after.as_dict()
744 elif after is None:
759 elif after is None:
745 after = []
760 after = []
746 if isinstance(follow, Dependency):
761 if isinstance(follow, Dependency):
747 follow = follow.as_dict()
762 follow = follow.as_dict()
748 elif follow is None:
763 elif follow is None:
749 follow = []
764 follow = []
750 subheader = dict(after=after, follow=follow)
765 subheader = dict(after=after, follow=follow)
751
766
752 bufs = ss.pack_apply_message(f,args,kwargs)
767 bufs = ss.pack_apply_message(f,args,kwargs)
753 content = dict(bound=bound)
768 content = dict(bound=bound)
754 msg = self.session.send(self._task_socket, "apply_request",
769 msg = self.session.send(self._task_socket, "apply_request",
755 content=content, buffers=bufs, subheader=subheader)
770 content=content, buffers=bufs, subheader=subheader)
756 msg_id = msg['msg_id']
771 msg_id = msg['msg_id']
757 self.outstanding.add(msg_id)
772 self.outstanding.add(msg_id)
758 self.history.append(msg_id)
773 self.history.append(msg_id)
759 if block:
774 if block:
760 self.barrier(msg_id)
775 self.barrier(msg_id)
761 return self.results[msg_id]
776 return self._maybe_raise(self.results[msg_id])
762 else:
777 else:
763 return msg_id
778 return msg_id
764
779
765 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
780 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
766 after=None, follow=None):
781 after=None, follow=None):
767 """Then underlying method for applying functions to specific engines
782 """Then underlying method for applying functions to specific engines
768 via the MUX queue."""
783 via the MUX queue."""
769
784
770 queues,targets = self._build_targets(targets)
785 queues,targets = self._build_targets(targets)
771 bufs = ss.pack_apply_message(f,args,kwargs)
786 bufs = ss.pack_apply_message(f,args,kwargs)
772 if isinstance(after, Dependency):
787 if isinstance(after, Dependency):
773 after = after.as_dict()
788 after = after.as_dict()
774 elif after is None:
789 elif after is None:
775 after = []
790 after = []
776 if isinstance(follow, Dependency):
791 if isinstance(follow, Dependency):
777 follow = follow.as_dict()
792 follow = follow.as_dict()
778 elif follow is None:
793 elif follow is None:
779 follow = []
794 follow = []
780 subheader = dict(after=after, follow=follow)
795 subheader = dict(after=after, follow=follow)
781 content = dict(bound=bound)
796 content = dict(bound=bound)
782 msg_ids = []
797 msg_ids = []
783 for queue in queues:
798 for queue in queues:
784 msg = self.session.send(self._mux_socket, "apply_request",
799 msg = self.session.send(self._mux_socket, "apply_request",
785 content=content, buffers=bufs,ident=queue, subheader=subheader)
800 content=content, buffers=bufs,ident=queue, subheader=subheader)
786 msg_id = msg['msg_id']
801 msg_id = msg['msg_id']
787 self.outstanding.add(msg_id)
802 self.outstanding.add(msg_id)
788 self.history.append(msg_id)
803 self.history.append(msg_id)
789 msg_ids.append(msg_id)
804 msg_ids.append(msg_id)
790 if block:
805 if block:
791 self.barrier(msg_ids)
806 self.barrier(msg_ids)
792 else:
807 else:
793 if len(msg_ids) == 1:
808 if len(msg_ids) == 1:
794 return msg_ids[0]
809 return msg_ids[0]
795 else:
810 else:
796 return msg_ids
811 return msg_ids
797 if len(msg_ids) == 1:
812 if len(msg_ids) == 1:
798 return self.results[msg_ids[0]]
813 return self._maybe_raise(self.results[msg_ids[0]])
799 else:
814 else:
800 result = {}
815 result = {}
801 for target,mid in zip(targets, msg_ids):
816 for target,mid in zip(targets, msg_ids):
802 result[target] = self.results[mid]
817 result[target] = self.results[mid]
803 return result
818 return error.collect_exceptions(result, f.__name__)
804
819
805 #--------------------------------------------------------------------------
820 #--------------------------------------------------------------------------
806 # Data movement
821 # Data movement
807 #--------------------------------------------------------------------------
822 #--------------------------------------------------------------------------
808
823
809 @defaultblock
824 @defaultblock
810 def push(self, ns, targets=None, block=None):
825 def push(self, ns, targets=None, block=None):
811 """Push the contents of `ns` into the namespace on `target`"""
826 """Push the contents of `ns` into the namespace on `target`"""
812 if not isinstance(ns, dict):
827 if not isinstance(ns, dict):
813 raise TypeError("Must be a dict, not %s"%type(ns))
828 raise TypeError("Must be a dict, not %s"%type(ns))
814 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
829 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
815 return result
830 return result
816
831
817 @defaultblock
832 @defaultblock
818 def pull(self, keys, targets=None, block=True):
833 def pull(self, keys, targets=None, block=True):
819 """Pull objects from `target`'s namespace by `keys`"""
834 """Pull objects from `target`'s namespace by `keys`"""
820 if isinstance(keys, str):
835 if isinstance(keys, str):
821 pass
836 pass
822 elif isinstance(keys, (list,tuple,set)):
837 elif isinstance(keys, (list,tuple,set)):
823 for key in keys:
838 for key in keys:
824 if not isinstance(key, str):
839 if not isinstance(key, str):
825 raise TypeError
840 raise TypeError
826 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
841 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
827 return result
842 return result
828
843
829 #--------------------------------------------------------------------------
844 #--------------------------------------------------------------------------
830 # Query methods
845 # Query methods
831 #--------------------------------------------------------------------------
846 #--------------------------------------------------------------------------
832
847
833 @spinfirst
848 @spinfirst
834 def get_results(self, msg_ids, status_only=False):
849 def get_results(self, msg_ids, status_only=False):
835 """Returns the result of the execute or task request with `msg_ids`.
850 """Returns the result of the execute or task request with `msg_ids`.
836
851
837 Parameters
852 Parameters
838 ----------
853 ----------
839 msg_ids : list of ints or msg_ids
854 msg_ids : list of ints or msg_ids
840 if int:
855 if int:
841 Passed as index to self.history for convenience.
856 Passed as index to self.history for convenience.
842 status_only : bool (default: False)
857 status_only : bool (default: False)
843 if False:
858 if False:
844 return the actual results
859 return the actual results
845 """
860 """
846 if not isinstance(msg_ids, (list,tuple)):
861 if not isinstance(msg_ids, (list,tuple)):
847 msg_ids = [msg_ids]
862 msg_ids = [msg_ids]
848 theids = []
863 theids = []
849 for msg_id in msg_ids:
864 for msg_id in msg_ids:
850 if isinstance(msg_id, int):
865 if isinstance(msg_id, int):
851 msg_id = self.history[msg_id]
866 msg_id = self.history[msg_id]
852 if not isinstance(msg_id, str):
867 if not isinstance(msg_id, str):
853 raise TypeError("msg_ids must be str, not %r"%msg_id)
868 raise TypeError("msg_ids must be str, not %r"%msg_id)
854 theids.append(msg_id)
869 theids.append(msg_id)
855
870
856 completed = []
871 completed = []
857 local_results = {}
872 local_results = {}
858 for msg_id in list(theids):
873 for msg_id in list(theids):
859 if msg_id in self.results:
874 if msg_id in self.results:
860 completed.append(msg_id)
875 completed.append(msg_id)
861 local_results[msg_id] = self.results[msg_id]
876 local_results[msg_id] = self.results[msg_id]
862 theids.remove(msg_id)
877 theids.remove(msg_id)
863
878
864 if theids: # some not locally cached
879 if theids: # some not locally cached
865 content = dict(msg_ids=theids, status_only=status_only)
880 content = dict(msg_ids=theids, status_only=status_only)
866 msg = self.session.send(self._query_socket, "result_request", content=content)
881 msg = self.session.send(self._query_socket, "result_request", content=content)
867 zmq.select([self._query_socket], [], [])
882 zmq.select([self._query_socket], [], [])
868 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
883 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
869 if self.debug:
884 if self.debug:
870 pprint(msg)
885 pprint(msg)
871 content = msg['content']
886 content = msg['content']
872 if content['status'] != 'ok':
887 if content['status'] != 'ok':
873 raise ss.unwrap_exception(content)
888 raise ss.unwrap_exception(content)
874 else:
889 else:
875 content = dict(completed=[],pending=[])
890 content = dict(completed=[],pending=[])
876 if not status_only:
891 if not status_only:
877 # load cached results into result:
892 # load cached results into result:
878 content['completed'].extend(completed)
893 content['completed'].extend(completed)
879 content.update(local_results)
894 content.update(local_results)
880 # update cache with results:
895 # update cache with results:
881 for msg_id in msg_ids:
896 for msg_id in msg_ids:
882 if msg_id in content['completed']:
897 if msg_id in content['completed']:
883 self.results[msg_id] = content[msg_id]
898 self.results[msg_id] = content[msg_id]
884 return content
899 return content
885
900
886 @spinfirst
901 @spinfirst
887 def queue_status(self, targets=None, verbose=False):
902 def queue_status(self, targets=None, verbose=False):
888 """Fetch the status of engine queues.
903 """Fetch the status of engine queues.
889
904
890 Parameters
905 Parameters
891 ----------
906 ----------
892 targets : int/str/list of ints/strs
907 targets : int/str/list of ints/strs
893 the engines on which to execute
908 the engines on which to execute
894 default : all
909 default : all
895 verbose : bool
910 verbose : bool
896 whether to return lengths only, or lists of ids for each element
911 whether to return lengths only, or lists of ids for each element
897 """
912 """
898 targets = self._build_targets(targets)[1]
913 targets = self._build_targets(targets)[1]
899 content = dict(targets=targets, verbose=verbose)
914 content = dict(targets=targets, verbose=verbose)
900 self.session.send(self._query_socket, "queue_request", content=content)
915 self.session.send(self._query_socket, "queue_request", content=content)
901 idents,msg = self.session.recv(self._query_socket, 0)
916 idents,msg = self.session.recv(self._query_socket, 0)
902 if self.debug:
917 if self.debug:
903 pprint(msg)
918 pprint(msg)
904 content = msg['content']
919 content = msg['content']
905 status = content.pop('status')
920 status = content.pop('status')
906 if status != 'ok':
921 if status != 'ok':
907 raise ss.unwrap_exception(content)
922 raise ss.unwrap_exception(content)
908 return content
923 return content
909
924
910 @spinfirst
925 @spinfirst
911 def purge_results(self, msg_ids=[], targets=[]):
926 def purge_results(self, msg_ids=[], targets=[]):
912 """Tell the controller to forget results.
927 """Tell the controller to forget results.
913
928
914 Individual results can be purged by msg_id, or the entire
929 Individual results can be purged by msg_id, or the entire
915 history of specific targets can
930 history of specific targets can
916
931
917 Parameters
932 Parameters
918 ----------
933 ----------
919 targets : int/str/list of ints/strs
934 targets : int/str/list of ints/strs
920 the targets
935 the targets
921 default : None
936 default : None
922 """
937 """
923 if not targets and not msg_ids:
938 if not targets and not msg_ids:
924 raise ValueError
939 raise ValueError
925 if targets:
940 if targets:
926 targets = self._build_targets(targets)[1]
941 targets = self._build_targets(targets)[1]
927 content = dict(targets=targets, msg_ids=msg_ids)
942 content = dict(targets=targets, msg_ids=msg_ids)
928 self.session.send(self._query_socket, "purge_request", content=content)
943 self.session.send(self._query_socket, "purge_request", content=content)
929 idents, msg = self.session.recv(self._query_socket, 0)
944 idents, msg = self.session.recv(self._query_socket, 0)
930 if self.debug:
945 if self.debug:
931 pprint(msg)
946 pprint(msg)
932 content = msg['content']
947 content = msg['content']
933 if content['status'] != 'ok':
948 if content['status'] != 'ok':
934 raise ss.unwrap_exception(content)
949 raise ss.unwrap_exception(content)
935
950
936 class AsynClient(Client):
951 class AsynClient(Client):
937 """An Asynchronous client, using the Tornado Event Loop.
952 """An Asynchronous client, using the Tornado Event Loop.
938 !!!unfinished!!!"""
953 !!!unfinished!!!"""
939 io_loop = None
954 io_loop = None
940 _queue_stream = None
955 _queue_stream = None
941 _notifier_stream = None
956 _notifier_stream = None
942 _task_stream = None
957 _task_stream = None
943 _control_stream = None
958 _control_stream = None
944
959
945 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
960 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
946 Client.__init__(self, addr, context, username, debug)
961 Client.__init__(self, addr, context, username, debug)
947 if io_loop is None:
962 if io_loop is None:
948 io_loop = ioloop.IOLoop.instance()
963 io_loop = ioloop.IOLoop.instance()
949 self.io_loop = io_loop
964 self.io_loop = io_loop
950
965
951 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
966 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
952 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
967 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
953 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
968 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
954 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
969 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
955
970
956 def spin(self):
971 def spin(self):
957 for stream in (self.queue_stream, self.notifier_stream,
972 for stream in (self.queue_stream, self.notifier_stream,
958 self.task_stream, self.control_stream):
973 self.task_stream, self.control_stream):
959 stream.flush()
974 stream.flush()
960
975
@@ -1,423 +1,430
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 Kernel adapted from kernel.py to use ZMQ Streams
3 Kernel adapted from kernel.py to use ZMQ Streams
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Imports
7 # Imports
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 # Standard library imports.
10 # Standard library imports.
11 from __future__ import print_function
11 from __future__ import print_function
12 import __builtin__
12 import __builtin__
13 from code import CommandCompiler
13 from code import CommandCompiler
14 import os
14 import os
15 import sys
15 import sys
16 import time
16 import time
17 import traceback
17 import traceback
18 from datetime import datetime
18 from datetime import datetime
19 from signal import SIGTERM, SIGKILL
19 from signal import SIGTERM, SIGKILL
20 from pprint import pprint
20 from pprint import pprint
21
21
22 # System library imports.
22 # System library imports.
23 import zmq
23 import zmq
24 from zmq.eventloop import ioloop, zmqstream
24 from zmq.eventloop import ioloop, zmqstream
25
25
26 # Local imports.
26 # Local imports.
27 from IPython.core import ultratb
27 from IPython.utils.traitlets import HasTraits, Instance, List
28 from IPython.utils.traitlets import HasTraits, Instance, List
28 from IPython.zmq.completer import KernelCompleter
29 from IPython.zmq.completer import KernelCompleter
29 from IPython.zmq.log import logger # a Logger object
30 from IPython.zmq.log import logger # a Logger object
30
31
31 from streamsession import StreamSession, Message, extract_header, serialize_object,\
32 from streamsession import StreamSession, Message, extract_header, serialize_object,\
32 unpack_apply_message, ISO8601, wrap_exception
33 unpack_apply_message, ISO8601, wrap_exception
33 from dependency import UnmetDependency
34 from dependency import UnmetDependency
34 import heartmonitor
35 import heartmonitor
35 from client import Client
36 from client import Client
36
37
37 def printer(*args):
38 def printer(*args):
38 pprint(args)
39 pprint(args)
39
40
40 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
41 # Main kernel class
42 # Main kernel class
42 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
43
44
44 class Kernel(HasTraits):
45 class Kernel(HasTraits):
45
46
46 #---------------------------------------------------------------------------
47 #---------------------------------------------------------------------------
47 # Kernel interface
48 # Kernel interface
48 #---------------------------------------------------------------------------
49 #---------------------------------------------------------------------------
49
50
50 session = Instance(StreamSession)
51 session = Instance(StreamSession)
51 shell_streams = Instance(list)
52 shell_streams = Instance(list)
52 control_stream = Instance(zmqstream.ZMQStream)
53 control_stream = Instance(zmqstream.ZMQStream)
53 task_stream = Instance(zmqstream.ZMQStream)
54 task_stream = Instance(zmqstream.ZMQStream)
54 iopub_stream = Instance(zmqstream.ZMQStream)
55 iopub_stream = Instance(zmqstream.ZMQStream)
55 client = Instance(Client)
56 client = Instance(Client)
56 loop = Instance(ioloop.IOLoop)
57 loop = Instance(ioloop.IOLoop)
57
58
58 def __init__(self, **kwargs):
59 def __init__(self, **kwargs):
59 super(Kernel, self).__init__(**kwargs)
60 super(Kernel, self).__init__(**kwargs)
60 self.identity = self.shell_streams[0].getsockopt(zmq.IDENTITY)
61 self.identity = self.shell_streams[0].getsockopt(zmq.IDENTITY)
61 self.user_ns = {}
62 self.user_ns = {}
62 self.history = []
63 self.history = []
63 self.compiler = CommandCompiler()
64 self.compiler = CommandCompiler()
64 self.completer = KernelCompleter(self.user_ns)
65 self.completer = KernelCompleter(self.user_ns)
65 self.aborted = set()
66 self.aborted = set()
66
67
67 # Build dict of handlers for message types
68 # Build dict of handlers for message types
68 self.shell_handlers = {}
69 self.shell_handlers = {}
69 self.control_handlers = {}
70 self.control_handlers = {}
70 for msg_type in ['execute_request', 'complete_request', 'apply_request',
71 for msg_type in ['execute_request', 'complete_request', 'apply_request',
71 'clear_request']:
72 'clear_request']:
72 self.shell_handlers[msg_type] = getattr(self, msg_type)
73 self.shell_handlers[msg_type] = getattr(self, msg_type)
73
74
74 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
75 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
75 self.control_handlers[msg_type] = getattr(self, msg_type)
76 self.control_handlers[msg_type] = getattr(self, msg_type)
76
77
78
79 def _wrap_exception(self, method=None):
80 e_info = dict(engineid=self.identity, method=method)
81 content=wrap_exception(e_info)
82 return content
83
77 #-------------------- control handlers -----------------------------
84 #-------------------- control handlers -----------------------------
78 def abort_queues(self):
85 def abort_queues(self):
79 for stream in self.shell_streams:
86 for stream in self.shell_streams:
80 if stream:
87 if stream:
81 self.abort_queue(stream)
88 self.abort_queue(stream)
82
89
83 def abort_queue(self, stream):
90 def abort_queue(self, stream):
84 while True:
91 while True:
85 try:
92 try:
86 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
93 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
87 except zmq.ZMQError as e:
94 except zmq.ZMQError as e:
88 if e.errno == zmq.EAGAIN:
95 if e.errno == zmq.EAGAIN:
89 break
96 break
90 else:
97 else:
91 return
98 return
92 else:
99 else:
93 if msg is None:
100 if msg is None:
94 return
101 return
95 else:
102 else:
96 idents,msg = msg
103 idents,msg = msg
97
104
98 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
105 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
99 # msg = self.reply_socket.recv_json()
106 # msg = self.reply_socket.recv_json()
100 print ("Aborting:", file=sys.__stdout__)
107 print ("Aborting:", file=sys.__stdout__)
101 print (Message(msg), file=sys.__stdout__)
108 print (Message(msg), file=sys.__stdout__)
102 msg_type = msg['msg_type']
109 msg_type = msg['msg_type']
103 reply_type = msg_type.split('_')[0] + '_reply'
110 reply_type = msg_type.split('_')[0] + '_reply'
104 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
111 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
105 # self.reply_socket.send(ident,zmq.SNDMORE)
112 # self.reply_socket.send(ident,zmq.SNDMORE)
106 # self.reply_socket.send_json(reply_msg)
113 # self.reply_socket.send_json(reply_msg)
107 reply_msg = self.session.send(stream, reply_type,
114 reply_msg = self.session.send(stream, reply_type,
108 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
115 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
109 print(Message(reply_msg), file=sys.__stdout__)
116 print(Message(reply_msg), file=sys.__stdout__)
110 # We need to wait a bit for requests to come in. This can probably
117 # We need to wait a bit for requests to come in. This can probably
111 # be set shorter for true asynchronous clients.
118 # be set shorter for true asynchronous clients.
112 time.sleep(0.05)
119 time.sleep(0.05)
113
120
114 def abort_request(self, stream, ident, parent):
121 def abort_request(self, stream, ident, parent):
115 """abort a specifig msg by id"""
122 """abort a specifig msg by id"""
116 msg_ids = parent['content'].get('msg_ids', None)
123 msg_ids = parent['content'].get('msg_ids', None)
117 if isinstance(msg_ids, basestring):
124 if isinstance(msg_ids, basestring):
118 msg_ids = [msg_ids]
125 msg_ids = [msg_ids]
119 if not msg_ids:
126 if not msg_ids:
120 self.abort_queues()
127 self.abort_queues()
121 for mid in msg_ids:
128 for mid in msg_ids:
122 self.aborted.add(str(mid))
129 self.aborted.add(str(mid))
123
130
124 content = dict(status='ok')
131 content = dict(status='ok')
125 reply_msg = self.session.send(stream, 'abort_reply', content=content,
132 reply_msg = self.session.send(stream, 'abort_reply', content=content,
126 parent=parent, ident=ident)[0]
133 parent=parent, ident=ident)[0]
127 print(Message(reply_msg), file=sys.__stdout__)
134 print(Message(reply_msg), file=sys.__stdout__)
128
135
129 def shutdown_request(self, stream, ident, parent):
136 def shutdown_request(self, stream, ident, parent):
130 """kill ourself. This should really be handled in an external process"""
137 """kill ourself. This should really be handled in an external process"""
131 try:
138 try:
132 self.abort_queues()
139 self.abort_queues()
133 except:
140 except:
134 content = wrap_exception()
141 content = self._wrap_exception('shutdown')
135 else:
142 else:
136 content = dict(parent['content'])
143 content = dict(parent['content'])
137 content['status'] = 'ok'
144 content['status'] = 'ok'
138 msg = self.session.send(stream, 'shutdown_reply',
145 msg = self.session.send(stream, 'shutdown_reply',
139 content=content, parent=parent, ident=ident)
146 content=content, parent=parent, ident=ident)
140 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
147 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
141 # content, parent, ident)
148 # content, parent, ident)
142 # print >> sys.__stdout__, msg
149 # print >> sys.__stdout__, msg
143 # time.sleep(0.2)
150 # time.sleep(0.2)
144 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
151 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
145 dc.start()
152 dc.start()
146
153
147 def dispatch_control(self, msg):
154 def dispatch_control(self, msg):
148 idents,msg = self.session.feed_identities(msg, copy=False)
155 idents,msg = self.session.feed_identities(msg, copy=False)
149 try:
156 try:
150 msg = self.session.unpack_message(msg, content=True, copy=False)
157 msg = self.session.unpack_message(msg, content=True, copy=False)
151 except:
158 except:
152 logger.error("Invalid Message", exc_info=True)
159 logger.error("Invalid Message", exc_info=True)
153 return
160 return
154
161
155 header = msg['header']
162 header = msg['header']
156 msg_id = header['msg_id']
163 msg_id = header['msg_id']
157
164
158 handler = self.control_handlers.get(msg['msg_type'], None)
165 handler = self.control_handlers.get(msg['msg_type'], None)
159 if handler is None:
166 if handler is None:
160 print ("UNKNOWN CONTROL MESSAGE TYPE:", msg, file=sys.__stderr__)
167 print ("UNKNOWN CONTROL MESSAGE TYPE:", msg, file=sys.__stderr__)
161 else:
168 else:
162 handler(self.control_stream, idents, msg)
169 handler(self.control_stream, idents, msg)
163
170
164
171
165 #-------------------- queue helpers ------------------------------
172 #-------------------- queue helpers ------------------------------
166
173
167 def check_dependencies(self, dependencies):
174 def check_dependencies(self, dependencies):
168 if not dependencies:
175 if not dependencies:
169 return True
176 return True
170 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
177 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
171 anyorall = dependencies[0]
178 anyorall = dependencies[0]
172 dependencies = dependencies[1]
179 dependencies = dependencies[1]
173 else:
180 else:
174 anyorall = 'all'
181 anyorall = 'all'
175 results = self.client.get_results(dependencies,status_only=True)
182 results = self.client.get_results(dependencies,status_only=True)
176 if results['status'] != 'ok':
183 if results['status'] != 'ok':
177 return False
184 return False
178
185
179 if anyorall == 'any':
186 if anyorall == 'any':
180 if not results['completed']:
187 if not results['completed']:
181 return False
188 return False
182 else:
189 else:
183 if results['pending']:
190 if results['pending']:
184 return False
191 return False
185
192
186 return True
193 return True
187
194
188 def check_aborted(self, msg_id):
195 def check_aborted(self, msg_id):
189 return msg_id in self.aborted
196 return msg_id in self.aborted
190
197
191 #-------------------- queue handlers -----------------------------
198 #-------------------- queue handlers -----------------------------
192
199
193 def clear_request(self, stream, idents, parent):
200 def clear_request(self, stream, idents, parent):
194 """Clear our namespace."""
201 """Clear our namespace."""
195 self.user_ns = {}
202 self.user_ns = {}
196 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
203 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
197 content = dict(status='ok'))
204 content = dict(status='ok'))
198
205
199 def execute_request(self, stream, ident, parent):
206 def execute_request(self, stream, ident, parent):
200 try:
207 try:
201 code = parent[u'content'][u'code']
208 code = parent[u'content'][u'code']
202 except:
209 except:
203 print("Got bad msg: ", file=sys.__stderr__)
210 print("Got bad msg: ", file=sys.__stderr__)
204 print(Message(parent), file=sys.__stderr__)
211 print(Message(parent), file=sys.__stderr__)
205 return
212 return
206 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
213 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
207 # self.iopub_stream.send(pyin_msg)
214 # self.iopub_stream.send(pyin_msg)
208 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
215 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
209 started = datetime.now().strftime(ISO8601)
216 started = datetime.now().strftime(ISO8601)
210 try:
217 try:
211 comp_code = self.compiler(code, '<zmq-kernel>')
218 comp_code = self.compiler(code, '<zmq-kernel>')
212 # allow for not overriding displayhook
219 # allow for not overriding displayhook
213 if hasattr(sys.displayhook, 'set_parent'):
220 if hasattr(sys.displayhook, 'set_parent'):
214 sys.displayhook.set_parent(parent)
221 sys.displayhook.set_parent(parent)
215 exec comp_code in self.user_ns, self.user_ns
222 exec comp_code in self.user_ns, self.user_ns
216 except:
223 except:
217 exc_content = wrap_exception()
224 exc_content = self._wrap_exception('execute')
218 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
225 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
219 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
226 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
220 reply_content = exc_content
227 reply_content = exc_content
221 else:
228 else:
222 reply_content = {'status' : 'ok'}
229 reply_content = {'status' : 'ok'}
223 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
230 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
224 # self.reply_socket.send(ident, zmq.SNDMORE)
231 # self.reply_socket.send(ident, zmq.SNDMORE)
225 # self.reply_socket.send_json(reply_msg)
232 # self.reply_socket.send_json(reply_msg)
226 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
233 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
227 ident=ident, subheader = dict(started=started))
234 ident=ident, subheader = dict(started=started))
228 print(Message(reply_msg), file=sys.__stdout__)
235 print(Message(reply_msg), file=sys.__stdout__)
229 if reply_msg['content']['status'] == u'error':
236 if reply_msg['content']['status'] == u'error':
230 self.abort_queues()
237 self.abort_queues()
231
238
232 def complete_request(self, stream, ident, parent):
239 def complete_request(self, stream, ident, parent):
233 matches = {'matches' : self.complete(parent),
240 matches = {'matches' : self.complete(parent),
234 'status' : 'ok'}
241 'status' : 'ok'}
235 completion_msg = self.session.send(stream, 'complete_reply',
242 completion_msg = self.session.send(stream, 'complete_reply',
236 matches, parent, ident)
243 matches, parent, ident)
237 # print >> sys.__stdout__, completion_msg
244 # print >> sys.__stdout__, completion_msg
238
245
239 def complete(self, msg):
246 def complete(self, msg):
240 return self.completer.complete(msg.content.line, msg.content.text)
247 return self.completer.complete(msg.content.line, msg.content.text)
241
248
242 def apply_request(self, stream, ident, parent):
249 def apply_request(self, stream, ident, parent):
243 print (parent)
250 print (parent)
244 try:
251 try:
245 content = parent[u'content']
252 content = parent[u'content']
246 bufs = parent[u'buffers']
253 bufs = parent[u'buffers']
247 msg_id = parent['header']['msg_id']
254 msg_id = parent['header']['msg_id']
248 bound = content.get('bound', False)
255 bound = content.get('bound', False)
249 except:
256 except:
250 print("Got bad msg: ", file=sys.__stderr__)
257 print("Got bad msg: ", file=sys.__stderr__)
251 print(Message(parent), file=sys.__stderr__)
258 print(Message(parent), file=sys.__stderr__)
252 return
259 return
253 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
260 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
254 # self.iopub_stream.send(pyin_msg)
261 # self.iopub_stream.send(pyin_msg)
255 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
262 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
256 sub = {'dependencies_met' : True, 'engine' : self.identity,
263 sub = {'dependencies_met' : True, 'engine' : self.identity,
257 'started': datetime.now().strftime(ISO8601)}
264 'started': datetime.now().strftime(ISO8601)}
258 try:
265 try:
259 # allow for not overriding displayhook
266 # allow for not overriding displayhook
260 if hasattr(sys.displayhook, 'set_parent'):
267 if hasattr(sys.displayhook, 'set_parent'):
261 sys.displayhook.set_parent(parent)
268 sys.displayhook.set_parent(parent)
262 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
269 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
263 if bound:
270 if bound:
264 working = self.user_ns
271 working = self.user_ns
265 suffix = str(msg_id).replace("-","")
272 suffix = str(msg_id).replace("-","")
266 prefix = "_"
273 prefix = "_"
267
274
268 else:
275 else:
269 working = dict()
276 working = dict()
270 suffix = prefix = "_" # prevent keyword collisions with lambda
277 suffix = prefix = "_" # prevent keyword collisions with lambda
271 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
278 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
272 # if f.fun
279 # if f.fun
273 fname = prefix+f.func_name.strip('<>')+suffix
280 fname = prefix+f.func_name.strip('<>')+suffix
274 argname = prefix+"args"+suffix
281 argname = prefix+"args"+suffix
275 kwargname = prefix+"kwargs"+suffix
282 kwargname = prefix+"kwargs"+suffix
276 resultname = prefix+"result"+suffix
283 resultname = prefix+"result"+suffix
277
284
278 ns = { fname : f, argname : args, kwargname : kwargs }
285 ns = { fname : f, argname : args, kwargname : kwargs }
279 # print ns
286 # print ns
280 working.update(ns)
287 working.update(ns)
281 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
288 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
282 exec code in working, working
289 exec code in working, working
283 result = working.get(resultname)
290 result = working.get(resultname)
284 # clear the namespace
291 # clear the namespace
285 if bound:
292 if bound:
286 for key in ns.iterkeys():
293 for key in ns.iterkeys():
287 self.user_ns.pop(key)
294 self.user_ns.pop(key)
288 else:
295 else:
289 del working
296 del working
290
297
291 packed_result,buf = serialize_object(result)
298 packed_result,buf = serialize_object(result)
292 result_buf = [packed_result]+buf
299 result_buf = [packed_result]+buf
293 except:
300 except:
294 exc_content = wrap_exception()
301 exc_content = self._wrap_exception('apply')
295 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
302 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
296 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
303 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
297 reply_content = exc_content
304 reply_content = exc_content
298 result_buf = []
305 result_buf = []
299
306
300 if etype is UnmetDependency:
307 if exc_content['ename'] == UnmetDependency.__name__:
301 sub['dependencies_met'] = False
308 sub['dependencies_met'] = False
302 else:
309 else:
303 reply_content = {'status' : 'ok'}
310 reply_content = {'status' : 'ok'}
304 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
311 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
305 # self.reply_socket.send(ident, zmq.SNDMORE)
312 # self.reply_socket.send(ident, zmq.SNDMORE)
306 # self.reply_socket.send_json(reply_msg)
313 # self.reply_socket.send_json(reply_msg)
307 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
314 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
308 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
315 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
309 print(Message(reply_msg), file=sys.__stdout__)
316 print(Message(reply_msg), file=sys.__stdout__)
310 # if reply_msg['content']['status'] == u'error':
317 # if reply_msg['content']['status'] == u'error':
311 # self.abort_queues()
318 # self.abort_queues()
312
319
313 def dispatch_queue(self, stream, msg):
320 def dispatch_queue(self, stream, msg):
314 self.control_stream.flush()
321 self.control_stream.flush()
315 idents,msg = self.session.feed_identities(msg, copy=False)
322 idents,msg = self.session.feed_identities(msg, copy=False)
316 try:
323 try:
317 msg = self.session.unpack_message(msg, content=True, copy=False)
324 msg = self.session.unpack_message(msg, content=True, copy=False)
318 except:
325 except:
319 logger.error("Invalid Message", exc_info=True)
326 logger.error("Invalid Message", exc_info=True)
320 return
327 return
321
328
322
329
323 header = msg['header']
330 header = msg['header']
324 msg_id = header['msg_id']
331 msg_id = header['msg_id']
325 if self.check_aborted(msg_id):
332 if self.check_aborted(msg_id):
326 self.aborted.remove(msg_id)
333 self.aborted.remove(msg_id)
327 # is it safe to assume a msg_id will not be resubmitted?
334 # is it safe to assume a msg_id will not be resubmitted?
328 reply_type = msg['msg_type'].split('_')[0] + '_reply'
335 reply_type = msg['msg_type'].split('_')[0] + '_reply'
329 reply_msg = self.session.send(stream, reply_type,
336 reply_msg = self.session.send(stream, reply_type,
330 content={'status' : 'aborted'}, parent=msg, ident=idents)
337 content={'status' : 'aborted'}, parent=msg, ident=idents)
331 return
338 return
332 handler = self.shell_handlers.get(msg['msg_type'], None)
339 handler = self.shell_handlers.get(msg['msg_type'], None)
333 if handler is None:
340 if handler is None:
334 print ("UNKNOWN MESSAGE TYPE:", msg, file=sys.__stderr__)
341 print ("UNKNOWN MESSAGE TYPE:", msg, file=sys.__stderr__)
335 else:
342 else:
336 handler(stream, idents, msg)
343 handler(stream, idents, msg)
337
344
338 def start(self):
345 def start(self):
339 #### stream mode:
346 #### stream mode:
340 if self.control_stream:
347 if self.control_stream:
341 self.control_stream.on_recv(self.dispatch_control, copy=False)
348 self.control_stream.on_recv(self.dispatch_control, copy=False)
342 self.control_stream.on_err(printer)
349 self.control_stream.on_err(printer)
343
350
344 def make_dispatcher(stream):
351 def make_dispatcher(stream):
345 def dispatcher(msg):
352 def dispatcher(msg):
346 return self.dispatch_queue(stream, msg)
353 return self.dispatch_queue(stream, msg)
347 return dispatcher
354 return dispatcher
348
355
349 for s in self.shell_streams:
356 for s in self.shell_streams:
350 s.on_recv(make_dispatcher(s), copy=False)
357 s.on_recv(make_dispatcher(s), copy=False)
351 s.on_err(printer)
358 s.on_err(printer)
352
359
353 if self.iopub_stream:
360 if self.iopub_stream:
354 self.iopub_stream.on_err(printer)
361 self.iopub_stream.on_err(printer)
355 self.iopub_stream.on_send(printer)
362 self.iopub_stream.on_send(printer)
356
363
357 #### while True mode:
364 #### while True mode:
358 # while True:
365 # while True:
359 # idle = True
366 # idle = True
360 # try:
367 # try:
361 # msg = self.shell_stream.socket.recv_multipart(
368 # msg = self.shell_stream.socket.recv_multipart(
362 # zmq.NOBLOCK, copy=False)
369 # zmq.NOBLOCK, copy=False)
363 # except zmq.ZMQError, e:
370 # except zmq.ZMQError, e:
364 # if e.errno != zmq.EAGAIN:
371 # if e.errno != zmq.EAGAIN:
365 # raise e
372 # raise e
366 # else:
373 # else:
367 # idle=False
374 # idle=False
368 # self.dispatch_queue(self.shell_stream, msg)
375 # self.dispatch_queue(self.shell_stream, msg)
369 #
376 #
370 # if not self.task_stream.empty():
377 # if not self.task_stream.empty():
371 # idle=False
378 # idle=False
372 # msg = self.task_stream.recv_multipart()
379 # msg = self.task_stream.recv_multipart()
373 # self.dispatch_queue(self.task_stream, msg)
380 # self.dispatch_queue(self.task_stream, msg)
374 # if idle:
381 # if idle:
375 # # don't busywait
382 # # don't busywait
376 # time.sleep(1e-3)
383 # time.sleep(1e-3)
377
384
378 def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
385 def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
379 client_addr=None, loop=None, context=None, key=None):
386 client_addr=None, loop=None, context=None, key=None):
380 # create loop, context, and session:
387 # create loop, context, and session:
381 if loop is None:
388 if loop is None:
382 loop = ioloop.IOLoop.instance()
389 loop = ioloop.IOLoop.instance()
383 if context is None:
390 if context is None:
384 context = zmq.Context()
391 context = zmq.Context()
385 c = context
392 c = context
386 session = StreamSession(key=key)
393 session = StreamSession(key=key)
387 # print (session.key)
394 # print (session.key)
388 print (control_addr, shell_addrs, iopub_addr, hb_addrs)
395 print (control_addr, shell_addrs, iopub_addr, hb_addrs)
389
396
390 # create Control Stream
397 # create Control Stream
391 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
398 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
392 control_stream.setsockopt(zmq.IDENTITY, identity)
399 control_stream.setsockopt(zmq.IDENTITY, identity)
393 control_stream.connect(control_addr)
400 control_stream.connect(control_addr)
394
401
395 # create Shell Streams (MUX, Task, etc.):
402 # create Shell Streams (MUX, Task, etc.):
396 shell_streams = []
403 shell_streams = []
397 for addr in shell_addrs:
404 for addr in shell_addrs:
398 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
405 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
399 stream.setsockopt(zmq.IDENTITY, identity)
406 stream.setsockopt(zmq.IDENTITY, identity)
400 stream.connect(addr)
407 stream.connect(addr)
401 shell_streams.append(stream)
408 shell_streams.append(stream)
402
409
403 # create iopub stream:
410 # create iopub stream:
404 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
411 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
405 iopub_stream.setsockopt(zmq.IDENTITY, identity)
412 iopub_stream.setsockopt(zmq.IDENTITY, identity)
406 iopub_stream.connect(iopub_addr)
413 iopub_stream.connect(iopub_addr)
407
414
408 # launch heartbeat
415 # launch heartbeat
409 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
416 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
410 heart.start()
417 heart.start()
411
418
412 # create (optional) Client
419 # create (optional) Client
413 if client_addr:
420 if client_addr:
414 client = Client(client_addr, username=identity)
421 client = Client(client_addr, username=identity)
415 else:
422 else:
416 client = None
423 client = None
417
424
418 kernel = Kernel(session=session, control_stream=control_stream,
425 kernel = Kernel(session=session, control_stream=control_stream,
419 shell_streams=shell_streams, iopub_stream=iopub_stream,
426 shell_streams=shell_streams, iopub_stream=iopub_stream,
420 client=client, loop=loop)
427 client=client, loop=loop)
421 kernel.start()
428 kernel.start()
422 return loop, c, kernel
429 return loop, c, kernel
423
430
@@ -1,531 +1,530
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4
4
5
5
6 import os
6 import os
7 import sys
7 import sys
8 import traceback
8 import traceback
9 import pprint
9 import pprint
10 import uuid
10 import uuid
11 from datetime import datetime
11 from datetime import datetime
12
12
13 import zmq
13 import zmq
14 from zmq.utils import jsonapi
14 from zmq.utils import jsonapi
15 from zmq.eventloop.zmqstream import ZMQStream
15 from zmq.eventloop.zmqstream import ZMQStream
16
16
17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 from IPython.utils.newserialized import serialize, unserialize
18 from IPython.utils.newserialized import serialize, unserialize
19
19
20 from IPython.zmq.parallel.error import RemoteError
21
20 try:
22 try:
21 import cPickle
23 import cPickle
22 pickle = cPickle
24 pickle = cPickle
23 except:
25 except:
24 cPickle = None
26 cPickle = None
25 import pickle
27 import pickle
26
28
27 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
28 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
29 if json_name in ('jsonlib', 'jsonlib2'):
31 if json_name in ('jsonlib', 'jsonlib2'):
30 use_json = True
32 use_json = True
31 elif json_name:
33 elif json_name:
32 if cPickle is None:
34 if cPickle is None:
33 use_json = True
35 use_json = True
34 else:
36 else:
35 use_json = False
37 use_json = False
36 else:
38 else:
37 use_json = False
39 use_json = False
38
40
39 def squash_unicode(obj):
41 def squash_unicode(obj):
40 if isinstance(obj,dict):
42 if isinstance(obj,dict):
41 for key in obj.keys():
43 for key in obj.keys():
42 obj[key] = squash_unicode(obj[key])
44 obj[key] = squash_unicode(obj[key])
43 if isinstance(key, unicode):
45 if isinstance(key, unicode):
44 obj[squash_unicode(key)] = obj.pop(key)
46 obj[squash_unicode(key)] = obj.pop(key)
45 elif isinstance(obj, list):
47 elif isinstance(obj, list):
46 for i,v in enumerate(obj):
48 for i,v in enumerate(obj):
47 obj[i] = squash_unicode(v)
49 obj[i] = squash_unicode(v)
48 elif isinstance(obj, unicode):
50 elif isinstance(obj, unicode):
49 obj = obj.encode('utf8')
51 obj = obj.encode('utf8')
50 return obj
52 return obj
51
53
52 if use_json:
54 if use_json:
53 default_packer = jsonapi.dumps
55 default_packer = jsonapi.dumps
54 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
56 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
55 else:
57 else:
56 default_packer = lambda o: pickle.dumps(o,-1)
58 default_packer = lambda o: pickle.dumps(o,-1)
57 default_unpacker = pickle.loads
59 default_unpacker = pickle.loads
58
60
59
61
60 DELIM="<IDS|MSG>"
62 DELIM="<IDS|MSG>"
61 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
63 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
62
64
63 def wrap_exception():
65 def wrap_exception(engine_info={}):
64 etype, evalue, tb = sys.exc_info()
66 etype, evalue, tb = sys.exc_info()
65 tb = traceback.format_exception(etype, evalue, tb)
67 stb = traceback.format_exception(etype, evalue, tb)
66 exc_content = {
68 exc_content = {
67 'status' : 'error',
69 'status' : 'error',
68 'traceback' : [ line.encode('utf8') for line in tb ],
70 'traceback' : stb,
69 'etype' : str(etype).encode('utf8'),
71 'ename' : unicode(etype.__name__),
70 'evalue' : evalue.encode('utf8')
72 'evalue' : unicode(evalue),
73 'engine_info' : engine_info
71 }
74 }
72 return exc_content
75 return exc_content
73
76
74 class KernelError(Exception):
75 pass
76
77 def unwrap_exception(content):
77 def unwrap_exception(content):
78 err = KernelError(content['etype'], content['evalue'])
78 err = RemoteError(content['ename'], content['evalue'],
79 err.evalue = content['evalue']
79 ''.join(content['traceback']),
80 err.etype = content['etype']
80 content.get('engine_info', {}))
81 err.traceback = ''.join(content['traceback'])
82 return err
81 return err
83
82
84
83
85 class Message(object):
84 class Message(object):
86 """A simple message object that maps dict keys to attributes.
85 """A simple message object that maps dict keys to attributes.
87
86
88 A Message can be created from a dict and a dict from a Message instance
87 A Message can be created from a dict and a dict from a Message instance
89 simply by calling dict(msg_obj)."""
88 simply by calling dict(msg_obj)."""
90
89
91 def __init__(self, msg_dict):
90 def __init__(self, msg_dict):
92 dct = self.__dict__
91 dct = self.__dict__
93 for k, v in dict(msg_dict).iteritems():
92 for k, v in dict(msg_dict).iteritems():
94 if isinstance(v, dict):
93 if isinstance(v, dict):
95 v = Message(v)
94 v = Message(v)
96 dct[k] = v
95 dct[k] = v
97
96
98 # Having this iterator lets dict(msg_obj) work out of the box.
97 # Having this iterator lets dict(msg_obj) work out of the box.
99 def __iter__(self):
98 def __iter__(self):
100 return iter(self.__dict__.iteritems())
99 return iter(self.__dict__.iteritems())
101
100
102 def __repr__(self):
101 def __repr__(self):
103 return repr(self.__dict__)
102 return repr(self.__dict__)
104
103
105 def __str__(self):
104 def __str__(self):
106 return pprint.pformat(self.__dict__)
105 return pprint.pformat(self.__dict__)
107
106
108 def __contains__(self, k):
107 def __contains__(self, k):
109 return k in self.__dict__
108 return k in self.__dict__
110
109
111 def __getitem__(self, k):
110 def __getitem__(self, k):
112 return self.__dict__[k]
111 return self.__dict__[k]
113
112
114
113
115 def msg_header(msg_id, msg_type, username, session):
114 def msg_header(msg_id, msg_type, username, session):
116 date=datetime.now().strftime(ISO8601)
115 date=datetime.now().strftime(ISO8601)
117 return locals()
116 return locals()
118
117
119 def extract_header(msg_or_header):
118 def extract_header(msg_or_header):
120 """Given a message or header, return the header."""
119 """Given a message or header, return the header."""
121 if not msg_or_header:
120 if not msg_or_header:
122 return {}
121 return {}
123 try:
122 try:
124 # See if msg_or_header is the entire message.
123 # See if msg_or_header is the entire message.
125 h = msg_or_header['header']
124 h = msg_or_header['header']
126 except KeyError:
125 except KeyError:
127 try:
126 try:
128 # See if msg_or_header is just the header
127 # See if msg_or_header is just the header
129 h = msg_or_header['msg_id']
128 h = msg_or_header['msg_id']
130 except KeyError:
129 except KeyError:
131 raise
130 raise
132 else:
131 else:
133 h = msg_or_header
132 h = msg_or_header
134 if not isinstance(h, dict):
133 if not isinstance(h, dict):
135 h = dict(h)
134 h = dict(h)
136 return h
135 return h
137
136
138 def rekey(dikt):
137 def rekey(dikt):
139 """Rekey a dict that has been forced to use str keys where there should be
138 """Rekey a dict that has been forced to use str keys where there should be
140 ints by json. This belongs in the jsonutil added by fperez."""
139 ints by json. This belongs in the jsonutil added by fperez."""
141 for k in dikt.iterkeys():
140 for k in dikt.iterkeys():
142 if isinstance(k, str):
141 if isinstance(k, str):
143 ik=fk=None
142 ik=fk=None
144 try:
143 try:
145 ik = int(k)
144 ik = int(k)
146 except ValueError:
145 except ValueError:
147 try:
146 try:
148 fk = float(k)
147 fk = float(k)
149 except ValueError:
148 except ValueError:
150 continue
149 continue
151 if ik is not None:
150 if ik is not None:
152 nk = ik
151 nk = ik
153 else:
152 else:
154 nk = fk
153 nk = fk
155 if nk in dikt:
154 if nk in dikt:
156 raise KeyError("already have key %r"%nk)
155 raise KeyError("already have key %r"%nk)
157 dikt[nk] = dikt.pop(k)
156 dikt[nk] = dikt.pop(k)
158 return dikt
157 return dikt
159
158
160 def serialize_object(obj, threshold=64e-6):
159 def serialize_object(obj, threshold=64e-6):
161 """Serialize an object into a list of sendable buffers.
160 """Serialize an object into a list of sendable buffers.
162
161
163 Parameters
162 Parameters
164 ----------
163 ----------
165
164
166 obj : object
165 obj : object
167 The object to be serialized
166 The object to be serialized
168 threshold : float
167 threshold : float
169 The threshold for not double-pickling the content.
168 The threshold for not double-pickling the content.
170
169
171
170
172 Returns
171 Returns
173 -------
172 -------
174 ('pmd', [bufs]) :
173 ('pmd', [bufs]) :
175 where pmd is the pickled metadata wrapper,
174 where pmd is the pickled metadata wrapper,
176 bufs is a list of data buffers
175 bufs is a list of data buffers
177 """
176 """
178 databuffers = []
177 databuffers = []
179 if isinstance(obj, (list, tuple)):
178 if isinstance(obj, (list, tuple)):
180 clist = canSequence(obj)
179 clist = canSequence(obj)
181 slist = map(serialize, clist)
180 slist = map(serialize, clist)
182 for s in slist:
181 for s in slist:
183 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
182 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
184 databuffers.append(s.getData())
183 databuffers.append(s.getData())
185 s.data = None
184 s.data = None
186 return pickle.dumps(slist,-1), databuffers
185 return pickle.dumps(slist,-1), databuffers
187 elif isinstance(obj, dict):
186 elif isinstance(obj, dict):
188 sobj = {}
187 sobj = {}
189 for k in sorted(obj.iterkeys()):
188 for k in sorted(obj.iterkeys()):
190 s = serialize(can(obj[k]))
189 s = serialize(can(obj[k]))
191 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
190 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
192 databuffers.append(s.getData())
191 databuffers.append(s.getData())
193 s.data = None
192 s.data = None
194 sobj[k] = s
193 sobj[k] = s
195 return pickle.dumps(sobj,-1),databuffers
194 return pickle.dumps(sobj,-1),databuffers
196 else:
195 else:
197 s = serialize(can(obj))
196 s = serialize(can(obj))
198 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
197 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
199 databuffers.append(s.getData())
198 databuffers.append(s.getData())
200 s.data = None
199 s.data = None
201 return pickle.dumps(s,-1),databuffers
200 return pickle.dumps(s,-1),databuffers
202
201
203
202
204 def unserialize_object(bufs):
203 def unserialize_object(bufs):
205 """reconstruct an object serialized by serialize_object from data buffers"""
204 """reconstruct an object serialized by serialize_object from data buffers"""
206 bufs = list(bufs)
205 bufs = list(bufs)
207 sobj = pickle.loads(bufs.pop(0))
206 sobj = pickle.loads(bufs.pop(0))
208 if isinstance(sobj, (list, tuple)):
207 if isinstance(sobj, (list, tuple)):
209 for s in sobj:
208 for s in sobj:
210 if s.data is None:
209 if s.data is None:
211 s.data = bufs.pop(0)
210 s.data = bufs.pop(0)
212 return uncanSequence(map(unserialize, sobj))
211 return uncanSequence(map(unserialize, sobj))
213 elif isinstance(sobj, dict):
212 elif isinstance(sobj, dict):
214 newobj = {}
213 newobj = {}
215 for k in sorted(sobj.iterkeys()):
214 for k in sorted(sobj.iterkeys()):
216 s = sobj[k]
215 s = sobj[k]
217 if s.data is None:
216 if s.data is None:
218 s.data = bufs.pop(0)
217 s.data = bufs.pop(0)
219 newobj[k] = uncan(unserialize(s))
218 newobj[k] = uncan(unserialize(s))
220 return newobj
219 return newobj
221 else:
220 else:
222 if sobj.data is None:
221 if sobj.data is None:
223 sobj.data = bufs.pop(0)
222 sobj.data = bufs.pop(0)
224 return uncan(unserialize(sobj))
223 return uncan(unserialize(sobj))
225
224
226 def pack_apply_message(f, args, kwargs, threshold=64e-6):
225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
227 """pack up a function, args, and kwargs to be sent over the wire
226 """pack up a function, args, and kwargs to be sent over the wire
228 as a series of buffers. Any object whose data is larger than `threshold`
227 as a series of buffers. Any object whose data is larger than `threshold`
229 will not have their data copied (currently only numpy arrays support zero-copy)"""
228 will not have their data copied (currently only numpy arrays support zero-copy)"""
230 msg = [pickle.dumps(can(f),-1)]
229 msg = [pickle.dumps(can(f),-1)]
231 databuffers = [] # for large objects
230 databuffers = [] # for large objects
232 sargs, bufs = serialize_object(args,threshold)
231 sargs, bufs = serialize_object(args,threshold)
233 msg.append(sargs)
232 msg.append(sargs)
234 databuffers.extend(bufs)
233 databuffers.extend(bufs)
235 skwargs, bufs = serialize_object(kwargs,threshold)
234 skwargs, bufs = serialize_object(kwargs,threshold)
236 msg.append(skwargs)
235 msg.append(skwargs)
237 databuffers.extend(bufs)
236 databuffers.extend(bufs)
238 msg.extend(databuffers)
237 msg.extend(databuffers)
239 return msg
238 return msg
240
239
241 def unpack_apply_message(bufs, g=None, copy=True):
240 def unpack_apply_message(bufs, g=None, copy=True):
242 """unpack f,args,kwargs from buffers packed by pack_apply_message()
241 """unpack f,args,kwargs from buffers packed by pack_apply_message()
243 Returns: original f,args,kwargs"""
242 Returns: original f,args,kwargs"""
244 bufs = list(bufs) # allow us to pop
243 bufs = list(bufs) # allow us to pop
245 assert len(bufs) >= 3, "not enough buffers!"
244 assert len(bufs) >= 3, "not enough buffers!"
246 if not copy:
245 if not copy:
247 for i in range(3):
246 for i in range(3):
248 bufs[i] = bufs[i].bytes
247 bufs[i] = bufs[i].bytes
249 cf = pickle.loads(bufs.pop(0))
248 cf = pickle.loads(bufs.pop(0))
250 sargs = list(pickle.loads(bufs.pop(0)))
249 sargs = list(pickle.loads(bufs.pop(0)))
251 skwargs = dict(pickle.loads(bufs.pop(0)))
250 skwargs = dict(pickle.loads(bufs.pop(0)))
252 # print sargs, skwargs
251 # print sargs, skwargs
253 f = uncan(cf, g)
252 f = uncan(cf, g)
254 for sa in sargs:
253 for sa in sargs:
255 if sa.data is None:
254 if sa.data is None:
256 m = bufs.pop(0)
255 m = bufs.pop(0)
257 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
256 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
258 if copy:
257 if copy:
259 sa.data = buffer(m)
258 sa.data = buffer(m)
260 else:
259 else:
261 sa.data = m.buffer
260 sa.data = m.buffer
262 else:
261 else:
263 if copy:
262 if copy:
264 sa.data = m
263 sa.data = m
265 else:
264 else:
266 sa.data = m.bytes
265 sa.data = m.bytes
267
266
268 args = uncanSequence(map(unserialize, sargs), g)
267 args = uncanSequence(map(unserialize, sargs), g)
269 kwargs = {}
268 kwargs = {}
270 for k in sorted(skwargs.iterkeys()):
269 for k in sorted(skwargs.iterkeys()):
271 sa = skwargs[k]
270 sa = skwargs[k]
272 if sa.data is None:
271 if sa.data is None:
273 sa.data = bufs.pop(0)
272 sa.data = bufs.pop(0)
274 kwargs[k] = uncan(unserialize(sa), g)
273 kwargs[k] = uncan(unserialize(sa), g)
275
274
276 return f,args,kwargs
275 return f,args,kwargs
277
276
278 class StreamSession(object):
277 class StreamSession(object):
279 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
278 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
280 debug=False
279 debug=False
281 key=None
280 key=None
282
281
283 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
282 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
284 if username is None:
283 if username is None:
285 username = os.environ.get('USER','username')
284 username = os.environ.get('USER','username')
286 self.username = username
285 self.username = username
287 if session is None:
286 if session is None:
288 self.session = str(uuid.uuid4())
287 self.session = str(uuid.uuid4())
289 else:
288 else:
290 self.session = session
289 self.session = session
291 self.msg_id = str(uuid.uuid4())
290 self.msg_id = str(uuid.uuid4())
292 if packer is None:
291 if packer is None:
293 self.pack = default_packer
292 self.pack = default_packer
294 else:
293 else:
295 if not callable(packer):
294 if not callable(packer):
296 raise TypeError("packer must be callable, not %s"%type(packer))
295 raise TypeError("packer must be callable, not %s"%type(packer))
297 self.pack = packer
296 self.pack = packer
298
297
299 if unpacker is None:
298 if unpacker is None:
300 self.unpack = default_unpacker
299 self.unpack = default_unpacker
301 else:
300 else:
302 if not callable(unpacker):
301 if not callable(unpacker):
303 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
302 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
304 self.unpack = unpacker
303 self.unpack = unpacker
305
304
306 if key is not None and keyfile is not None:
305 if key is not None and keyfile is not None:
307 raise TypeError("Must specify key OR keyfile, not both")
306 raise TypeError("Must specify key OR keyfile, not both")
308 if keyfile is not None:
307 if keyfile is not None:
309 with open(keyfile) as f:
308 with open(keyfile) as f:
310 self.key = f.read().strip()
309 self.key = f.read().strip()
311 else:
310 else:
312 self.key = key
311 self.key = key
313 # print key, keyfile, self.key
312 # print key, keyfile, self.key
314 self.none = self.pack({})
313 self.none = self.pack({})
315
314
316 def msg_header(self, msg_type):
315 def msg_header(self, msg_type):
317 h = msg_header(self.msg_id, msg_type, self.username, self.session)
316 h = msg_header(self.msg_id, msg_type, self.username, self.session)
318 self.msg_id = str(uuid.uuid4())
317 self.msg_id = str(uuid.uuid4())
319 return h
318 return h
320
319
321 def msg(self, msg_type, content=None, parent=None, subheader=None):
320 def msg(self, msg_type, content=None, parent=None, subheader=None):
322 msg = {}
321 msg = {}
323 msg['header'] = self.msg_header(msg_type)
322 msg['header'] = self.msg_header(msg_type)
324 msg['msg_id'] = msg['header']['msg_id']
323 msg['msg_id'] = msg['header']['msg_id']
325 msg['parent_header'] = {} if parent is None else extract_header(parent)
324 msg['parent_header'] = {} if parent is None else extract_header(parent)
326 msg['msg_type'] = msg_type
325 msg['msg_type'] = msg_type
327 msg['content'] = {} if content is None else content
326 msg['content'] = {} if content is None else content
328 sub = {} if subheader is None else subheader
327 sub = {} if subheader is None else subheader
329 msg['header'].update(sub)
328 msg['header'].update(sub)
330 return msg
329 return msg
331
330
332 def check_key(self, msg_or_header):
331 def check_key(self, msg_or_header):
333 """Check that a message's header has the right key"""
332 """Check that a message's header has the right key"""
334 if self.key is None:
333 if self.key is None:
335 return True
334 return True
336 header = extract_header(msg_or_header)
335 header = extract_header(msg_or_header)
337 return header.get('key', None) == self.key
336 return header.get('key', None) == self.key
338
337
339
338
340 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
339 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
341 """Build and send a message via stream or socket.
340 """Build and send a message via stream or socket.
342
341
343 Parameters
342 Parameters
344 ----------
343 ----------
345
344
346 stream : zmq.Socket or ZMQStream
345 stream : zmq.Socket or ZMQStream
347 the socket-like object used to send the data
346 the socket-like object used to send the data
348 msg_type : str or Message/dict
347 msg_type : str or Message/dict
349 Normally, msg_type will be
348 Normally, msg_type will be
350
349
351
350
352
351
353 Returns
352 Returns
354 -------
353 -------
355 (msg,sent) : tuple
354 (msg,sent) : tuple
356 msg : Message
355 msg : Message
357 the nice wrapped dict-like object containing the headers
356 the nice wrapped dict-like object containing the headers
358
357
359 """
358 """
360 if isinstance(msg_type, (Message, dict)):
359 if isinstance(msg_type, (Message, dict)):
361 # we got a Message, not a msg_type
360 # we got a Message, not a msg_type
362 # don't build a new Message
361 # don't build a new Message
363 msg = msg_type
362 msg = msg_type
364 content = msg['content']
363 content = msg['content']
365 else:
364 else:
366 msg = self.msg(msg_type, content, parent, subheader)
365 msg = self.msg(msg_type, content, parent, subheader)
367 buffers = [] if buffers is None else buffers
366 buffers = [] if buffers is None else buffers
368 to_send = []
367 to_send = []
369 if isinstance(ident, list):
368 if isinstance(ident, list):
370 # accept list of idents
369 # accept list of idents
371 to_send.extend(ident)
370 to_send.extend(ident)
372 elif ident is not None:
371 elif ident is not None:
373 to_send.append(ident)
372 to_send.append(ident)
374 to_send.append(DELIM)
373 to_send.append(DELIM)
375 if self.key is not None:
374 if self.key is not None:
376 to_send.append(self.key)
375 to_send.append(self.key)
377 to_send.append(self.pack(msg['header']))
376 to_send.append(self.pack(msg['header']))
378 to_send.append(self.pack(msg['parent_header']))
377 to_send.append(self.pack(msg['parent_header']))
379
378
380 if content is None:
379 if content is None:
381 content = self.none
380 content = self.none
382 elif isinstance(content, dict):
381 elif isinstance(content, dict):
383 content = self.pack(content)
382 content = self.pack(content)
384 elif isinstance(content, str):
383 elif isinstance(content, str):
385 # content is already packed, as in a relayed message
384 # content is already packed, as in a relayed message
386 pass
385 pass
387 else:
386 else:
388 raise TypeError("Content incorrect type: %s"%type(content))
387 raise TypeError("Content incorrect type: %s"%type(content))
389 to_send.append(content)
388 to_send.append(content)
390 flag = 0
389 flag = 0
391 if buffers:
390 if buffers:
392 flag = zmq.SNDMORE
391 flag = zmq.SNDMORE
393 stream.send_multipart(to_send, flag, copy=False)
392 stream.send_multipart(to_send, flag, copy=False)
394 for b in buffers[:-1]:
393 for b in buffers[:-1]:
395 stream.send(b, flag, copy=False)
394 stream.send(b, flag, copy=False)
396 if buffers:
395 if buffers:
397 stream.send(buffers[-1], copy=False)
396 stream.send(buffers[-1], copy=False)
398 omsg = Message(msg)
397 omsg = Message(msg)
399 if self.debug:
398 if self.debug:
400 pprint.pprint(omsg)
399 pprint.pprint(omsg)
401 pprint.pprint(to_send)
400 pprint.pprint(to_send)
402 pprint.pprint(buffers)
401 pprint.pprint(buffers)
403 return omsg
402 return omsg
404
403
405 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
404 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
406 """Send a raw message via idents.
405 """Send a raw message via idents.
407
406
408 Parameters
407 Parameters
409 ----------
408 ----------
410 msg : list of sendable buffers"""
409 msg : list of sendable buffers"""
411 to_send = []
410 to_send = []
412 if isinstance(ident, str):
411 if isinstance(ident, str):
413 ident = [ident]
412 ident = [ident]
414 if ident is not None:
413 if ident is not None:
415 to_send.extend(ident)
414 to_send.extend(ident)
416 to_send.append(DELIM)
415 to_send.append(DELIM)
417 if self.key is not None:
416 if self.key is not None:
418 to_send.append(self.key)
417 to_send.append(self.key)
419 to_send.extend(msg)
418 to_send.extend(msg)
420 stream.send_multipart(msg, flags, copy=copy)
419 stream.send_multipart(msg, flags, copy=copy)
421
420
422 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
421 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
423 """receives and unpacks a message
422 """receives and unpacks a message
424 returns [idents], msg"""
423 returns [idents], msg"""
425 if isinstance(socket, ZMQStream):
424 if isinstance(socket, ZMQStream):
426 socket = socket.socket
425 socket = socket.socket
427 try:
426 try:
428 msg = socket.recv_multipart(mode)
427 msg = socket.recv_multipart(mode)
429 except zmq.ZMQError as e:
428 except zmq.ZMQError as e:
430 if e.errno == zmq.EAGAIN:
429 if e.errno == zmq.EAGAIN:
431 # We can convert EAGAIN to None as we know in this case
430 # We can convert EAGAIN to None as we know in this case
432 # recv_json won't return None.
431 # recv_json won't return None.
433 return None
432 return None
434 else:
433 else:
435 raise
434 raise
436 # return an actual Message object
435 # return an actual Message object
437 # determine the number of idents by trying to unpack them.
436 # determine the number of idents by trying to unpack them.
438 # this is terrible:
437 # this is terrible:
439 idents, msg = self.feed_identities(msg, copy)
438 idents, msg = self.feed_identities(msg, copy)
440 try:
439 try:
441 return idents, self.unpack_message(msg, content=content, copy=copy)
440 return idents, self.unpack_message(msg, content=content, copy=copy)
442 except Exception as e:
441 except Exception as e:
443 print (idents, msg)
442 print (idents, msg)
444 # TODO: handle it
443 # TODO: handle it
445 raise e
444 raise e
446
445
447 def feed_identities(self, msg, copy=True):
446 def feed_identities(self, msg, copy=True):
448 """This is a completely horrible thing, but it strips the zmq
447 """This is a completely horrible thing, but it strips the zmq
449 ident prefixes off of a message. It will break if any identities
448 ident prefixes off of a message. It will break if any identities
450 are unpackable by self.unpack."""
449 are unpackable by self.unpack."""
451 msg = list(msg)
450 msg = list(msg)
452 idents = []
451 idents = []
453 while len(msg) > 3:
452 while len(msg) > 3:
454 if copy:
453 if copy:
455 s = msg[0]
454 s = msg[0]
456 else:
455 else:
457 s = msg[0].bytes
456 s = msg[0].bytes
458 if s == DELIM:
457 if s == DELIM:
459 msg.pop(0)
458 msg.pop(0)
460 break
459 break
461 else:
460 else:
462 idents.append(s)
461 idents.append(s)
463 msg.pop(0)
462 msg.pop(0)
464
463
465 return idents, msg
464 return idents, msg
466
465
467 def unpack_message(self, msg, content=True, copy=True):
466 def unpack_message(self, msg, content=True, copy=True):
468 """Return a message object from the format
467 """Return a message object from the format
469 sent by self.send.
468 sent by self.send.
470
469
471 Parameters:
470 Parameters:
472 -----------
471 -----------
473
472
474 content : bool (True)
473 content : bool (True)
475 whether to unpack the content dict (True),
474 whether to unpack the content dict (True),
476 or leave it serialized (False)
475 or leave it serialized (False)
477
476
478 copy : bool (True)
477 copy : bool (True)
479 whether to return the bytes (True),
478 whether to return the bytes (True),
480 or the non-copying Message object in each place (False)
479 or the non-copying Message object in each place (False)
481
480
482 """
481 """
483 ikey = int(self.key is not None)
482 ikey = int(self.key is not None)
484 minlen = 3 + ikey
483 minlen = 3 + ikey
485 if not len(msg) >= minlen:
484 if not len(msg) >= minlen:
486 raise TypeError("malformed message, must have at least %i elements"%minlen)
485 raise TypeError("malformed message, must have at least %i elements"%minlen)
487 message = {}
486 message = {}
488 if not copy:
487 if not copy:
489 for i in range(minlen):
488 for i in range(minlen):
490 msg[i] = msg[i].bytes
489 msg[i] = msg[i].bytes
491 if ikey:
490 if ikey:
492 if not self.key == msg[0]:
491 if not self.key == msg[0]:
493 raise KeyError("Invalid Session Key: %s"%msg[0])
492 raise KeyError("Invalid Session Key: %s"%msg[0])
494 message['header'] = self.unpack(msg[ikey+0])
493 message['header'] = self.unpack(msg[ikey+0])
495 message['msg_type'] = message['header']['msg_type']
494 message['msg_type'] = message['header']['msg_type']
496 message['parent_header'] = self.unpack(msg[ikey+1])
495 message['parent_header'] = self.unpack(msg[ikey+1])
497 if content:
496 if content:
498 message['content'] = self.unpack(msg[ikey+2])
497 message['content'] = self.unpack(msg[ikey+2])
499 else:
498 else:
500 message['content'] = msg[ikey+2]
499 message['content'] = msg[ikey+2]
501
500
502 # message['buffers'] = msg[3:]
501 # message['buffers'] = msg[3:]
503 # else:
502 # else:
504 # message['header'] = self.unpack(msg[0].bytes)
503 # message['header'] = self.unpack(msg[0].bytes)
505 # message['msg_type'] = message['header']['msg_type']
504 # message['msg_type'] = message['header']['msg_type']
506 # message['parent_header'] = self.unpack(msg[1].bytes)
505 # message['parent_header'] = self.unpack(msg[1].bytes)
507 # if content:
506 # if content:
508 # message['content'] = self.unpack(msg[2].bytes)
507 # message['content'] = self.unpack(msg[2].bytes)
509 # else:
508 # else:
510 # message['content'] = msg[2].bytes
509 # message['content'] = msg[2].bytes
511
510
512 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
511 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
513 return message
512 return message
514
513
515
514
516
515
517 def test_msg2obj():
516 def test_msg2obj():
518 am = dict(x=1)
517 am = dict(x=1)
519 ao = Message(am)
518 ao = Message(am)
520 assert ao.x == am['x']
519 assert ao.x == am['x']
521
520
522 am['y'] = dict(z=1)
521 am['y'] = dict(z=1)
523 ao = Message(am)
522 ao = Message(am)
524 assert ao.y.z == am['y']['z']
523 assert ao.y.z == am['y']['z']
525
524
526 k1, k2 = 'y', 'z'
525 k1, k2 = 'y', 'z'
527 assert ao[k1][k2] == am[k1][k2]
526 assert ao[k1][k2] == am[k1][k2]
528
527
529 am2 = dict(ao)
528 am2 = dict(ao)
530 assert am['x'] == am2['x']
529 assert am['x'] == am2['x']
531 assert am['y']['z'] == am2['y']['z']
530 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now