Merge branch 'sessionwork'
Brian E. Granger
r4517:8de5e115 merge
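
Reviewer note: the change threaded through all three files below is the new Session message layout — 'msg_type' and 'msg_id' now live inside the message's 'header' sub-dict rather than at the top level, so every consumer switches from msg['msg_type'] to msg['header']['msg_type'] (and likewise for msg_id). A minimal sketch of the two shapes, with illustrative field values (not the exact wire format):

    # Old layout: type and id at the top level of the message dict.
    old_msg = {'msg_type': 'stream',
               'msg_id': 'abc123',
               'parent_header': {'session': 's1', 'msg_id': 'req1'},
               'content': {'name': 'stdout', 'data': 'hi\n'}}

    # New layout: type and id nested under 'header'.
    new_msg = {'header': {'msg_type': 'stream', 'msg_id': 'abc123'},
               'parent_header': {'session': 's1', 'msg_id': 'req1'},
               'content': {'name': 'stdout', 'data': 'hi\n'}}

    assert new_msg['header']['msg_type'] == old_msg['msg_type']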
@@ -1,109 +1,109 @@
 """ Defines a convenient mix-in class for implementing Qt frontends.
 """

 class BaseFrontendMixin(object):
     """ A mix-in class for implementing Qt frontends.

     To handle messages of a particular type, frontends need only define an
     appropriate handler method. For example, to handle 'stream' messages, define
     a '_handle_stream(msg)' method.
     """

     #---------------------------------------------------------------------------
     # 'BaseFrontendMixin' concrete interface
     #---------------------------------------------------------------------------

     def _get_kernel_manager(self):
         """ Returns the current kernel manager.
         """
         return self._kernel_manager

     def _set_kernel_manager(self, kernel_manager):
         """ Disconnect from the current kernel manager (if any) and set a new
             kernel manager.
         """
         # Disconnect the old kernel manager, if necessary.
         old_manager = self._kernel_manager
         if old_manager is not None:
             old_manager.started_channels.disconnect(self._started_channels)
             old_manager.stopped_channels.disconnect(self._stopped_channels)

             # Disconnect the old kernel manager's channels.
             old_manager.sub_channel.message_received.disconnect(self._dispatch)
             old_manager.shell_channel.message_received.disconnect(self._dispatch)
             old_manager.stdin_channel.message_received.disconnect(self._dispatch)
             old_manager.hb_channel.kernel_died.disconnect(
                 self._handle_kernel_died)

             # Handle the case where the old kernel manager is still listening.
             if old_manager.channels_running:
                 self._stopped_channels()

         # Set the new kernel manager.
         self._kernel_manager = kernel_manager
         if kernel_manager is None:
             return

         # Connect the new kernel manager.
         kernel_manager.started_channels.connect(self._started_channels)
         kernel_manager.stopped_channels.connect(self._stopped_channels)

         # Connect the new kernel manager's channels.
         kernel_manager.sub_channel.message_received.connect(self._dispatch)
         kernel_manager.shell_channel.message_received.connect(self._dispatch)
         kernel_manager.stdin_channel.message_received.connect(self._dispatch)
         kernel_manager.hb_channel.kernel_died.connect(self._handle_kernel_died)

         # Handle the case where the kernel manager started channels before
         # we connected.
         if kernel_manager.channels_running:
             self._started_channels()

     kernel_manager = property(_get_kernel_manager, _set_kernel_manager)

     #---------------------------------------------------------------------------
     # 'BaseFrontendMixin' abstract interface
     #---------------------------------------------------------------------------

     def _handle_kernel_died(self, since_last_heartbeat):
         """ This is called when the ``kernel_died`` signal is emitted.

         This method is called when the kernel heartbeat has not been
         active for a certain amount of time. The typical action will be to
         give the user the option of restarting the kernel.

         Parameters
         ----------
         since_last_heartbeat : float
             The time since the heartbeat was last received.
         """

     def _started_channels(self):
         """ Called when the KernelManager channels have started listening or
             when the frontend is assigned an already listening KernelManager.
         """

     def _stopped_channels(self):
         """ Called when the KernelManager channels have stopped listening or
             when a listening KernelManager is removed from the frontend.
         """

     #---------------------------------------------------------------------------
     # 'BaseFrontendMixin' protected interface
     #---------------------------------------------------------------------------

     def _dispatch(self, msg):
         """ Calls the frontend handler associated with the message type of the
             given message.
         """
-        msg_type = msg['msg_type']
+        msg_type = msg['header']['msg_type']
         handler = getattr(self, '_handle_' + msg_type, None)
         if handler:
             handler(msg)

     def _is_from_this_session(self, msg):
         """ Returns whether a reply from the kernel originated from a request
             from this frontend.
         """
         session = self._kernel_manager.session.session
         return msg['parent_header']['session'] == session
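
For context on the hunk above: a frontend needs only a method named '_handle_<msg_type>', and _dispatch finds it via getattr. A minimal sketch of a consumer (StreamPrinter is illustrative, not part of IPython):

    class StreamPrinter(BaseFrontendMixin):
        # Toy frontend: print 'stream' output belonging to this session.

        _kernel_manager = None  # attribute the mix-in reads and writes

        def _handle_stream(self, msg):
            # _dispatch routes here when msg['header']['msg_type'] == 'stream'
            if self._is_from_this_session(msg):
                print(msg['content']['data'])

Assigning a kernel manager to the kernel_manager property then wires each channel's message_received signal to _dispatch.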
@@ -1,243 +1,243 @@
 """ Defines a KernelManager that provides signals and slots.
 """

 # System library imports.
 from IPython.external.qt import QtCore

 # IPython imports.
 from IPython.utils.traitlets import Type
 from IPython.zmq.kernelmanager import KernelManager, SubSocketChannel, \
     ShellSocketChannel, StdInSocketChannel, HBSocketChannel
 from util import MetaQObjectHasTraits, SuperQObject


 class SocketChannelQObject(SuperQObject):

     # Emitted when the channel is started.
     started = QtCore.Signal()

     # Emitted when the channel is stopped.
     stopped = QtCore.Signal()

     #---------------------------------------------------------------------------
     # 'ZMQSocketChannel' interface
     #---------------------------------------------------------------------------

     def start(self):
         """ Reimplemented to emit signal.
         """
         super(SocketChannelQObject, self).start()
         self.started.emit()

     def stop(self):
         """ Reimplemented to emit signal.
         """
         super(SocketChannelQObject, self).stop()
         self.stopped.emit()


 class QtShellSocketChannel(SocketChannelQObject, ShellSocketChannel):

     # Emitted when any message is received.
     message_received = QtCore.Signal(object)

     # Emitted when a reply has been received for the corresponding request
     # type.
     execute_reply = QtCore.Signal(object)
     complete_reply = QtCore.Signal(object)
     object_info_reply = QtCore.Signal(object)
     history_reply = QtCore.Signal(object)

     # Emitted when the first reply comes back.
     first_reply = QtCore.Signal()

     # Used by the first_reply signal logic to determine if a reply is the
     # first.
     _handlers_called = False

     #---------------------------------------------------------------------------
     # 'ShellSocketChannel' interface
     #---------------------------------------------------------------------------

     def call_handlers(self, msg):
         """ Reimplemented to emit signals instead of making callbacks.
         """
         # Emit the generic signal.
         self.message_received.emit(msg)

         # Emit signals for specialized message types.
-        msg_type = msg['msg_type']
+        msg_type = msg['header']['msg_type']
         signal = getattr(self, msg_type, None)
         if signal:
             signal.emit(msg)

         if not self._handlers_called:
             self.first_reply.emit()
             self._handlers_called = True

     #---------------------------------------------------------------------------
     # 'QtShellSocketChannel' interface
     #---------------------------------------------------------------------------

     def reset_first_reply(self):
         """ Reset the first_reply signal to fire again on the next reply.
         """
         self._handlers_called = False


 class QtSubSocketChannel(SocketChannelQObject, SubSocketChannel):

     # Emitted when any message is received.
     message_received = QtCore.Signal(object)

     # Emitted when a message of type 'stream' is received.
     stream_received = QtCore.Signal(object)

     # Emitted when a message of type 'pyin' is received.
     pyin_received = QtCore.Signal(object)

     # Emitted when a message of type 'pyout' is received.
     pyout_received = QtCore.Signal(object)

     # Emitted when a message of type 'pyerr' is received.
     pyerr_received = QtCore.Signal(object)

     # Emitted when a message of type 'display_data' is received.
     display_data_received = QtCore.Signal(object)

     # Emitted when a crash report message is received from the kernel's
     # last-resort sys.excepthook.
     crash_received = QtCore.Signal(object)

     # Emitted when a shutdown is noticed.
     shutdown_reply_received = QtCore.Signal(object)

     #---------------------------------------------------------------------------
     # 'SubSocketChannel' interface
     #---------------------------------------------------------------------------

     def call_handlers(self, msg):
         """ Reimplemented to emit signals instead of making callbacks.
         """
         # Emit the generic signal.
         self.message_received.emit(msg)
         # Emit signals for specialized message types.
-        msg_type = msg['msg_type']
+        msg_type = msg['header']['msg_type']
         signal = getattr(self, msg_type + '_received', None)
         if signal:
             signal.emit(msg)
         elif msg_type in ('stdout', 'stderr'):
             self.stream_received.emit(msg)

     def flush(self):
         """ Reimplemented to ensure that signals are dispatched immediately.
         """
         super(QtSubSocketChannel, self).flush()
         QtCore.QCoreApplication.instance().processEvents()


 class QtStdInSocketChannel(SocketChannelQObject, StdInSocketChannel):

     # Emitted when any message is received.
     message_received = QtCore.Signal(object)

     # Emitted when an input request is received.
     input_requested = QtCore.Signal(object)

     #---------------------------------------------------------------------------
     # 'StdInSocketChannel' interface
     #---------------------------------------------------------------------------

     def call_handlers(self, msg):
         """ Reimplemented to emit signals instead of making callbacks.
         """
         # Emit the generic signal.
         self.message_received.emit(msg)

         # Emit signals for specialized message types.
-        msg_type = msg['msg_type']
+        msg_type = msg['header']['msg_type']
         if msg_type == 'input_request':
             self.input_requested.emit(msg)


 class QtHBSocketChannel(SocketChannelQObject, HBSocketChannel):

     # Emitted when the kernel has died.
     kernel_died = QtCore.Signal(object)

     #---------------------------------------------------------------------------
     # 'HBSocketChannel' interface
     #---------------------------------------------------------------------------

     def call_handlers(self, since_last_heartbeat):
         """ Reimplemented to emit signals instead of making callbacks.
         """
         # Emit the generic signal.
         self.kernel_died.emit(since_last_heartbeat)


 class QtKernelManager(KernelManager, SuperQObject):
     """ A KernelManager that provides signals and slots.
     """

     __metaclass__ = MetaQObjectHasTraits

     # Emitted when the kernel manager has started listening.
     started_channels = QtCore.Signal()

     # Emitted when the kernel manager has stopped listening.
     stopped_channels = QtCore.Signal()

     # Use Qt-specific channel classes that emit signals.
     sub_channel_class = Type(QtSubSocketChannel)
     shell_channel_class = Type(QtShellSocketChannel)
     stdin_channel_class = Type(QtStdInSocketChannel)
     hb_channel_class = Type(QtHBSocketChannel)

     #---------------------------------------------------------------------------
     # 'KernelManager' interface
     #---------------------------------------------------------------------------

     #------ Kernel process management ------------------------------------------

     def start_kernel(self, *args, **kw):
         """ Reimplemented for proper heartbeat management.
         """
         if self._shell_channel is not None:
             self._shell_channel.reset_first_reply()
         super(QtKernelManager, self).start_kernel(*args, **kw)

     #------ Channel management -------------------------------------------------

     def start_channels(self, *args, **kw):
         """ Reimplemented to emit signal.
         """
         super(QtKernelManager, self).start_channels(*args, **kw)
         self.started_channels.emit()

     def stop_channels(self):
         """ Reimplemented to emit signal.
         """
         super(QtKernelManager, self).stop_channels()
         self.stopped_channels.emit()

     @property
     def shell_channel(self):
         """ Reimplemented for proper heartbeat management.
         """
         if self._shell_channel is None:
             self._shell_channel = super(QtKernelManager, self).shell_channel
             self._shell_channel.first_reply.connect(self._first_reply)
         return self._shell_channel

     #---------------------------------------------------------------------------
     # Protected interface
     #---------------------------------------------------------------------------

     def _first_reply(self):
         """ Unpauses the heartbeat channel when the first reply is received on
             the execute channel. Note that this will *not* start the heartbeat
             channel if it is not already running!
         """
         if self._hb_channel is not None:
             self._hb_channel.unpause()
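
A hedged usage sketch of the Qt classes above (module path and app setup assumed from the IPython frontend of this era; details are illustrative, not canonical):

    from IPython.external.qt import QtGui
    from IPython.frontend.qt.kernelmanager import QtKernelManager  # assumed path

    app = QtGui.QApplication([])

    def on_sub_msg(msg):
        # channel messages now carry their type under 'header'
        print(msg['header']['msg_type'])

    manager = QtKernelManager()
    manager.start_kernel()        # launch the kernel process
    manager.start_channels()      # emits started_channels
    manager.sub_channel.message_received.connect(on_sub_msg)
    manager.hb_channel.kernel_died.connect(lambda since: app.quit())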
@@ -1,1435 +1,1435 @@
 """A semi-synchronous Client for the ZMQ cluster

 Authors:

 * MinRK
 """
 #-----------------------------------------------------------------------------
 # Copyright (C) 2010-2011 The IPython Development Team
 #
 # Distributed under the terms of the BSD License. The full license is in
 # the file COPYING, distributed as part of this software.
 #-----------------------------------------------------------------------------

 #-----------------------------------------------------------------------------
 # Imports
 #-----------------------------------------------------------------------------

 import os
 import json
 import sys
 import time
 import warnings
 from datetime import datetime
 from getpass import getpass
 from pprint import pprint

 pjoin = os.path.join

 import zmq
 # from zmq.eventloop import ioloop, zmqstream

 from IPython.config.configurable import MultipleInstanceError
 from IPython.core.application import BaseIPythonApplication

 from IPython.utils.jsonutil import rekey
 from IPython.utils.localinterfaces import LOCAL_IPS
 from IPython.utils.path import get_ipython_dir
 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
                                     Dict, List, Bool, Set)
 from IPython.external.decorator import decorator
 from IPython.external.ssh import tunnel

 from IPython.parallel import error
 from IPython.parallel import util

 from IPython.zmq.session import Session, Message

 from .asyncresult import AsyncResult, AsyncHubResult
 from IPython.core.profiledir import ProfileDir, ProfileDirError
 from .view import DirectView, LoadBalancedView

 if sys.version_info[0] >= 3:
     # xrange is used in a couple 'isinstance' tests in py2
     # should be just 'range' in 3k
     xrange = range

 #--------------------------------------------------------------------------
 # Decorators for Client methods
 #--------------------------------------------------------------------------

 @decorator
 def spin_first(f, self, *args, **kwargs):
     """Call spin() to sync state prior to calling the method."""
     self.spin()
     return f(self, *args, **kwargs)


 #--------------------------------------------------------------------------
 # Classes
 #--------------------------------------------------------------------------

 class Metadata(dict):
     """Subclass of dict for initializing metadata values.

     Attribute access works on keys.

     These objects have a strict set of keys - errors will raise if you try
     to add new keys.
     """
     def __init__(self, *args, **kwargs):
         dict.__init__(self)
         md = {'msg_id' : None,
               'submitted' : None,
               'started' : None,
               'completed' : None,
               'received' : None,
               'engine_uuid' : None,
               'engine_id' : None,
               'follow' : None,
               'after' : None,
               'status' : None,

               'pyin' : None,
               'pyout' : None,
               'pyerr' : None,
               'stdout' : '',
               'stderr' : '',
             }
         self.update(md)
         self.update(dict(*args, **kwargs))

     def __getattr__(self, key):
         """getattr aliased to getitem"""
         if key in self.iterkeys():
             return self[key]
         else:
             raise AttributeError(key)

     def __setattr__(self, key, value):
         """setattr aliased to setitem, with strict"""
         if key in self.iterkeys():
             self[key] = value
         else:
             raise AttributeError(key)

     def __setitem__(self, key, value):
         """strict static key enforcement"""
         if key in self.iterkeys():
             dict.__setitem__(self, key, value)
         else:
             raise KeyError(key)


 class Client(HasTraits):
     """A semi-synchronous client to the IPython ZMQ cluster

     Parameters
     ----------

     url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
         Connection information for the Hub's registration. If a json connector
         file is given, then likely no further configuration is necessary.
         [Default: use profile]
     profile : bytes
         The name of the Cluster profile to be used to find connector information.
         If run from an IPython application, the default profile will be the same
         as the running application, otherwise it will be 'default'.
     context : zmq.Context
         Pass an existing zmq.Context instance, otherwise the client will create its own.
     debug : bool
         flag for lots of message printing for debug purposes
     timeout : int/float
         time (in seconds) to wait for connection replies from the Hub
         [Default: 10]

     #-------------- session related args ----------------

     config : Config object
         If specified, this will be relayed to the Session for configuration
     username : str
         set username for the session object
     packer : str (import_string) or callable
         Can be either the simple keyword 'json' or 'pickle', or an import_string to a
         function to serialize messages. Must support same input as
         JSON, and output must be bytes.
         You can pass a callable directly as `pack`
     unpacker : str (import_string) or callable
         The inverse of packer. Only necessary if packer is specified as *not* one
         of 'json' or 'pickle'.

     #-------------- ssh related args ----------------
     # These are args for configuring the ssh tunnel to be used
     # credentials are used to forward connections over ssh to the Controller
     # Note that the ip given in `addr` needs to be relative to sshserver
     # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
     # and set sshserver as the same machine the Controller is on. However,
     # the only requirement is that sshserver is able to see the Controller
     # (i.e. is within the same trusted network).

     sshserver : str
         A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
         If keyfile or password is specified, and this is not, it will default to
         the ip given in addr.
     sshkey : str; path to public ssh key file
         This specifies a key to be used in ssh login, default None.
         Regular default ssh keys will be used without specifying this argument.
     password : str
         Your ssh password to sshserver. Note that if this is left None,
         you will be prompted for it if passwordless key based login is unavailable.
     paramiko : bool
         flag for whether to use paramiko instead of shell ssh for tunneling.
         [default: True on win32, False else]

     ------- exec authentication args -------
     If even localhost is untrusted, you can have some protection against
     unauthorized execution by signing messages with HMAC digests.
     Messages are still sent as cleartext, so if someone can snoop your
     loopback traffic this will not protect your privacy, but will prevent
     unauthorized execution.

     exec_key : str
         an authentication key or file containing a key
         default: None


     Attributes
     ----------

     ids : list of int engine IDs
         requesting the ids attribute always synchronizes
         the registration state. To request ids without synchronization,
         use semi-private _ids attributes.

     history : list of msg_ids
         a list of msg_ids, keeping track of all the execution
         messages you have submitted in order.

     outstanding : set of msg_ids
         a set of msg_ids that have been submitted, but whose
         results have not yet been received.

     results : dict
         a dict of all our results, keyed by msg_id

     block : bool
         determines default behavior when block not specified
         in execution methods

     Methods
     -------

     spin
         flushes incoming results and registration state changes
         control methods spin, and requesting `ids` also ensures up to date

     wait
         wait on one or more msg_ids

     execution methods
         apply
         legacy: execute, run

     data movement
         push, pull, scatter, gather

     query methods
         queue_status, get_result, purge, result_status

     control methods
         abort, shutdown

     """


     block = Bool(False)
     outstanding = Set()
     results = Instance('collections.defaultdict', (dict,))
     metadata = Instance('collections.defaultdict', (Metadata,))
     history = List()
     debug = Bool(False)

     profile=Unicode()
     def _profile_default(self):
         if BaseIPythonApplication.initialized():
             # an IPython app *might* be running, try to get its profile
             try:
                 return BaseIPythonApplication.instance().profile
             except (AttributeError, MultipleInstanceError):
                 # could be a *different* subclass of config.Application,
                 # which would raise one of these two errors.
                 return u'default'
         else:
             return u'default'


     _outstanding_dict = Instance('collections.defaultdict', (set,))
     _ids = List()
     _connected=Bool(False)
     _ssh=Bool(False)
     _context = Instance('zmq.Context')
     _config = Dict()
     _engines=Instance(util.ReverseDict, (), {})
     # _hub_socket=Instance('zmq.Socket')
     _query_socket=Instance('zmq.Socket')
     _control_socket=Instance('zmq.Socket')
     _iopub_socket=Instance('zmq.Socket')
     _notification_socket=Instance('zmq.Socket')
     _mux_socket=Instance('zmq.Socket')
     _task_socket=Instance('zmq.Socket')
     _task_scheme=Unicode()
     _closed = False
     _ignored_control_replies=Int(0)
     _ignored_hub_replies=Int(0)

     def __new__(self, *args, **kw):
         # don't raise on positional args
         return HasTraits.__new__(self, **kw)

     def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
             context=None, debug=False, exec_key=None,
             sshserver=None, sshkey=None, password=None, paramiko=None,
             timeout=10, **extra_args
             ):
         if profile:
             super(Client, self).__init__(debug=debug, profile=profile)
         else:
             super(Client, self).__init__(debug=debug)
         if context is None:
             context = zmq.Context.instance()
         self._context = context

         self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
         if self._cd is not None:
             if url_or_file is None:
                 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
         assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
             " Please specify at least one of url_or_file or profile."

         try:
             util.validate_url(url_or_file)
         except AssertionError:
             if not os.path.exists(url_or_file):
                 if self._cd:
                     url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
             with open(url_or_file) as f:
                 cfg = json.loads(f.read())
         else:
             cfg = {'url':url_or_file}

         # sync defaults from args, json:
         if sshserver:
             cfg['ssh'] = sshserver
         if exec_key:
             cfg['exec_key'] = exec_key
         exec_key = cfg['exec_key']
         location = cfg.setdefault('location', None)
         cfg['url'] = util.disambiguate_url(cfg['url'], location)
         url = cfg['url']
         proto,addr,port = util.split_url(url)
         if location is not None and addr == '127.0.0.1':
             # location specified, and connection is expected to be local
             if location not in LOCAL_IPS and not sshserver:
                 # load ssh from JSON *only* if the controller is not on
                 # this machine
                 sshserver=cfg['ssh']
             if location not in LOCAL_IPS and not sshserver:
                 # warn if no ssh specified, but SSH is probably needed
                 # This is only a warning, because the most likely cause
                 # is a local Controller on a laptop whose IP is dynamic
                 warnings.warn("""
             Controller appears to be listening on localhost, but not on this machine.
             If this is true, you should specify Client(...,sshserver='you@%s')
             or instruct your controller to listen on an external IP."""%location,
                 RuntimeWarning)
         elif not sshserver:
             # otherwise sync with cfg
             sshserver = cfg['ssh']

         self._config = cfg

         self._ssh = bool(sshserver or sshkey or password)
         if self._ssh and sshserver is None:
             # default to ssh via localhost
             sshserver = url.split('://')[1].split(':')[0]
         if self._ssh and password is None:
             if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                 password=False
             else:
                 password = getpass("SSH Password for %s: "%sshserver)
         ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

         # configure and construct the session
         if exec_key is not None:
             if os.path.isfile(exec_key):
                 extra_args['keyfile'] = exec_key
             else:
                 exec_key = util.asbytes(exec_key)
                 extra_args['key'] = exec_key
         self.session = Session(**extra_args)

         self._query_socket = self._context.socket(zmq.XREQ)
         self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
         if self._ssh:
             tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
         else:
             self._query_socket.connect(url)

         self.session.debug = self.debug

         self._notification_handlers = {'registration_notification' : self._register_engine,
                                         'unregistration_notification' : self._unregister_engine,
                                         'shutdown_notification' : lambda msg: self.close(),
                                         }
         self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                 'apply_reply' : self._handle_apply_reply}
         self._connect(sshserver, ssh_kwargs, timeout)

     def __del__(self):
         """cleanup sockets, but _not_ context."""
         self.close()

     def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
         if ipython_dir is None:
             ipython_dir = get_ipython_dir()
         if profile_dir is not None:
             try:
                 self._cd = ProfileDir.find_profile_dir(profile_dir)
                 return
             except ProfileDirError:
                 pass
         elif profile is not None:
             try:
                 self._cd = ProfileDir.find_profile_dir_by_name(
                     ipython_dir, profile)
                 return
             except ProfileDirError:
                 pass
         self._cd = None

     def _update_engines(self, engines):
         """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
         for k,v in engines.iteritems():
             eid = int(k)
             self._engines[eid] = v
             self._ids.append(eid)
         self._ids = sorted(self._ids)
         if sorted(self._engines.keys()) != range(len(self._engines)) and \
                         self._task_scheme == 'pure' and self._task_socket:
             self._stop_scheduling_tasks()

     def _stop_scheduling_tasks(self):
         """Stop scheduling tasks because an engine has been unregistered
         from a pure ZMQ scheduler.
         """
         self._task_socket.close()
         self._task_socket = None
         msg = "An engine has been unregistered, and we are using pure " +\
               "ZMQ task scheduling. Task farming will be disabled."
         if self.outstanding:
             msg += " If you were running tasks when this happened, " +\
                    "some `outstanding` msg_ids may never resolve."
         warnings.warn(msg, RuntimeWarning)

     def _build_targets(self, targets):
         """Turn valid target IDs or 'all' into two lists:
         (int_ids, uuids).
         """
         if not self._ids:
             # flush notification socket if no engines yet, just in case
             if not self.ids:
                 raise error.NoEnginesRegistered("Can't build targets without any engines")

         if targets is None:
             targets = self._ids
         elif isinstance(targets, basestring):
             if targets.lower() == 'all':
                 targets = self._ids
             else:
                 raise TypeError("%r not valid str target, must be 'all'"%(targets))
         elif isinstance(targets, int):
             if targets < 0:
                 targets = self.ids[targets]
             if targets not in self._ids:
                 raise IndexError("No such engine: %i"%targets)
             targets = [targets]

         if isinstance(targets, slice):
             indices = range(len(self._ids))[targets]
             ids = self.ids
             targets = [ ids[i] for i in indices ]

         if not isinstance(targets, (tuple, list, xrange)):
             raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

         return [util.asbytes(self._engines[t]) for t in targets], list(targets)

     def _connect(self, sshserver, ssh_kwargs, timeout):
         """setup all our socket connections to the cluster. This is called from
         __init__."""

         # Maybe allow reconnecting?
         if self._connected:
             return
         self._connected=True

         def connect_socket(s, url):
             url = util.disambiguate_url(url, self._config['location'])
             if self._ssh:
                 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
             else:
                 return s.connect(url)

         self.session.send(self._query_socket, 'connection_request')
         # use Poller because zmq.select has wrong units in pyzmq 2.1.7
         poller = zmq.Poller()
         poller.register(self._query_socket, zmq.POLLIN)
         # poll expects milliseconds, timeout is seconds
         evts = poller.poll(timeout*1000)
         if not evts:
             raise error.TimeoutError("Hub connection request timed out")
         idents,msg = self.session.recv(self._query_socket,mode=0)
         if self.debug:
             pprint(msg)
         msg = Message(msg)
         content = msg.content
         self._config['registration'] = dict(content)
         if content.status == 'ok':
             ident = util.asbytes(self.session.session)
             if content.mux:
                 self._mux_socket = self._context.socket(zmq.XREQ)
                 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                 connect_socket(self._mux_socket, content.mux)
             if content.task:
                 self._task_scheme, task_addr = content.task
                 self._task_socket = self._context.socket(zmq.XREQ)
                 self._task_socket.setsockopt(zmq.IDENTITY, ident)
                 connect_socket(self._task_socket, task_addr)
             if content.notification:
                 self._notification_socket = self._context.socket(zmq.SUB)
                 connect_socket(self._notification_socket, content.notification)
                 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
             # if content.query:
             #     self._query_socket = self._context.socket(zmq.XREQ)
             #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
             #     connect_socket(self._query_socket, content.query)
             if content.control:
                 self._control_socket = self._context.socket(zmq.XREQ)
                 self._control_socket.setsockopt(zmq.IDENTITY, ident)
                 connect_socket(self._control_socket, content.control)
             if content.iopub:
                 self._iopub_socket = self._context.socket(zmq.SUB)
                 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                 connect_socket(self._iopub_socket, content.iopub)
             self._update_engines(dict(content.engines))
         else:
             self._connected = False
             raise Exception("Failed to connect!")

     #--------------------------------------------------------------------------
     # handlers and callbacks for incoming messages
     #--------------------------------------------------------------------------

     def _unwrap_exception(self, content):
         """unwrap exception, and remap engine_id to int."""
         e = error.unwrap_exception(content)
         # print e.traceback
         if e.engine_info:
             e_uuid = e.engine_info['engine_uuid']
             eid = self._engines[e_uuid]
             e.engine_info['engine_id'] = eid
         return e

     def _extract_metadata(self, header, parent, content):
         md = {'msg_id' : parent['msg_id'],
               'received' : datetime.now(),
               'engine_uuid' : header.get('engine', None),
               'follow' : parent.get('follow', []),
               'after' : parent.get('after', []),
               'status' : content['status'],
             }

         if md['engine_uuid'] is not None:
             md['engine_id'] = self._engines.get(md['engine_uuid'], None)

         if 'date' in parent:
             md['submitted'] = parent['date']
         if 'started' in header:
             md['started'] = header['started']
         if 'date' in header:
             md['completed'] = header['date']
         return md

     def _register_engine(self, msg):
         """Register a new engine, and update our connection info."""
         content = msg['content']
         eid = content['id']
         d = {eid : content['queue']}
         self._update_engines(d)

     def _unregister_engine(self, msg):
         """Unregister an engine that has died."""
         content = msg['content']
         eid = int(content['id'])
         if eid in self._ids:
             self._ids.remove(eid)
         uuid = self._engines.pop(eid)

         self._handle_stranded_msgs(eid, uuid)

         if self._task_socket and self._task_scheme == 'pure':
             self._stop_scheduling_tasks()

     def _handle_stranded_msgs(self, eid, uuid):
         """Handle messages known to be on an engine when the engine unregisters.

         It is possible that this will fire prematurely - that is, an engine will
         go down after completing a result, and the client will be notified
         of the unregistration and later receive the successful result.
         """

         outstanding = self._outstanding_dict[uuid]

         for msg_id in list(outstanding):
             if msg_id in self.results:
                 # we already have this result
                 continue
             try:
                 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
             except:
                 content = error.wrap_exception()
             # build a fake message:
             parent = {}
             header = {}
             parent['msg_id'] = msg_id
             header['engine'] = uuid
             header['date'] = datetime.now()
             msg = dict(parent_header=parent, header=header, content=content)
             self._handle_apply_reply(msg)

     def _handle_execute_reply(self, msg):
         """Save the reply to an execute_request into our results.

         execute messages are never actually used. apply is used instead.
         """

         parent = msg['parent_header']
         msg_id = parent['msg_id']
         if msg_id not in self.outstanding:
             if msg_id in self.history:
                 print ("got stale result: %s"%msg_id)
             else:
                 print ("got unknown result: %s"%msg_id)
         else:
             self.outstanding.remove(msg_id)
             self.results[msg_id] = self._unwrap_exception(msg['content'])

     def _handle_apply_reply(self, msg):
         """Save the reply to an apply_request into our results."""
         parent = msg['parent_header']
         msg_id = parent['msg_id']
         if msg_id not in self.outstanding:
             if msg_id in self.history:
                 print ("got stale result: %s"%msg_id)
                 print self.results[msg_id]
                 print msg
             else:
                 print ("got unknown result: %s"%msg_id)
         else:
             self.outstanding.remove(msg_id)
             content = msg['content']
             header = msg['header']

             # construct metadata:
             md = self.metadata[msg_id]
             md.update(self._extract_metadata(header, parent, content))
             # is this redundant?
             self.metadata[msg_id] = md

             e_outstanding = self._outstanding_dict[md['engine_uuid']]
             if msg_id in e_outstanding:
                 e_outstanding.remove(msg_id)

             # construct result:
             if content['status'] == 'ok':
                 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
             elif content['status'] == 'aborted':
                 self.results[msg_id] = error.TaskAborted(msg_id)
             elif content['status'] == 'resubmitted':
                 # TODO: handle resubmission
                 pass
             else:
                 self.results[msg_id] = self._unwrap_exception(content)

     def _flush_notifications(self):
         """Flush notifications of engine registrations waiting
         in ZMQ queue."""
         idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
         while msg is not None:
             if self.debug:
                 pprint(msg)
-            msg_type = msg['msg_type']
+            msg_type = msg['header']['msg_type']
             handler = self._notification_handlers.get(msg_type, None)
             if handler is None:
                 raise Exception("Unhandled message type: %s"%msg_type)
             else:
                 handler(msg)
             idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

     def _flush_results(self, sock):
         """Flush task or queue results waiting in ZMQ queue."""
         idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
         while msg is not None:
             if self.debug:
                 pprint(msg)
-            msg_type = msg['msg_type']
+            msg_type = msg['header']['msg_type']
             handler = self._queue_handlers.get(msg_type, None)
             if handler is None:
                 raise Exception("Unhandled message type: %s"%msg_type)
             else:
                 handler(msg)
             idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

     def _flush_control(self, sock):
         """Flush replies from the control channel waiting
         in the ZMQ queue.

         Currently: ignore them."""
         if self._ignored_control_replies <= 0:
             return
         idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
         while msg is not None:
             self._ignored_control_replies -= 1
             if self.debug:
                 pprint(msg)
             idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

     def _flush_ignored_control(self):
         """flush ignored control replies"""
         while self._ignored_control_replies > 0:
             self.session.recv(self._control_socket)
             self._ignored_control_replies -= 1

     def _flush_ignored_hub_replies(self):
         ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
         while msg is not None:
             ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

     def _flush_iopub(self, sock):
         """Flush replies from the iopub channel waiting
         in the ZMQ queue.
         """
         idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
         while msg is not None:
             if self.debug:
                 pprint(msg)
             parent = msg['parent_header']
             msg_id = parent['msg_id']
             content = msg['content']
             header = msg['header']
-            msg_type = msg['msg_type']
+            msg_type = msg['header']['msg_type']

             # init metadata:
             md = self.metadata[msg_id]

             if msg_type == 'stream':
                 name = content['name']
                 s = md[name] or ''
                 md[name] = s + content['data']
             elif msg_type == 'pyerr':
                 md.update({'pyerr' : self._unwrap_exception(content)})
             elif msg_type == 'pyin':
                 md.update({'pyin' : content['code']})
             else:
                 md.update({msg_type : content.get('data', '')})

             # redundant?
             self.metadata[msg_id] = md

             idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

     #--------------------------------------------------------------------------
     # len, getitem
     #--------------------------------------------------------------------------

     def __len__(self):
         """len(client) returns # of engines."""
         return len(self.ids)

     def __getitem__(self, key):
         """index access returns DirectView multiplexer objects

         Must be int, slice, or list/tuple/xrange of ints"""
         if not isinstance(key, (int, slice, tuple, list, xrange)):
             raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
         else:
             return self.direct_view(key)

     #--------------------------------------------------------------------------
     # Begin public methods
     #--------------------------------------------------------------------------

     @property
     def ids(self):
         """Always up-to-date ids property."""
         self._flush_notifications()
         # always copy:
         return list(self._ids)

     def close(self):
         if self._closed:
             return
         snames = filter(lambda n: n.endswith('socket'), dir(self))
         for socket in map(lambda name: getattr(self, name), snames):
             if isinstance(socket, zmq.Socket) and not socket.closed:
                 socket.close()
         self._closed = True

     def spin(self):
         """Flush any registration notifications and execution results
         waiting in the ZMQ queue.
         """
         if self._notification_socket:
             self._flush_notifications()
         if self._mux_socket:
             self._flush_results(self._mux_socket)
         if self._task_socket:
             self._flush_results(self._task_socket)
         if self._control_socket:
             self._flush_control(self._control_socket)
         if self._iopub_socket:
             self._flush_iopub(self._iopub_socket)
         if self._query_socket:
             self._flush_ignored_hub_replies()

     def wait(self, jobs=None, timeout=-1):
         """waits on one or more `jobs`, for up to `timeout` seconds.

         Parameters
         ----------

         jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
             ints are indices to self.history
             strs are msg_ids
             default: wait on all outstanding messages
         timeout : float
             a time in seconds, after which to give up.
             default is -1, which means no timeout

         Returns
         -------

         True : when all msg_ids are done
         False : timeout reached, some msg_ids still outstanding
         """
         tic = time.time()
         if jobs is None:
             theids = self.outstanding
         else:
             if isinstance(jobs, (int, basestring, AsyncResult)):
                 jobs = [jobs]
             theids = set()
             for job in jobs:
                 if isinstance(job, int):
                     # index access
                     job = self.history[job]
                 elif isinstance(job, AsyncResult):
                     map(theids.add, job.msg_ids)
                     continue
                 theids.add(job)
         if not theids.intersection(self.outstanding):
             return True
         self.spin()
         while theids.intersection(self.outstanding):
             if timeout >= 0 and ( time.time()-tic ) > timeout:
                 break
             time.sleep(1e-3)
             self.spin()
         return len(theids.intersection(self.outstanding)) == 0

     #--------------------------------------------------------------------------
     # Control methods
     #--------------------------------------------------------------------------

     @spin_first
     def clear(self, targets=None, block=None):
         """Clear the namespace in target(s)."""
         block = self.block if block is None else block
         targets = self._build_targets(targets)[0]
         for t in targets:
             self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
         error = False
         if block:
             self._flush_ignored_control()
             for i in range(len(targets)):
                 idents,msg = self.session.recv(self._control_socket,0)
                 if self.debug:
                     pprint(msg)
                 if msg['content']['status'] != 'ok':
                     error = self._unwrap_exception(msg['content'])
         else:
             self._ignored_control_replies += len(targets)
         if error:
             raise error


     @spin_first
     def abort(self, jobs=None, targets=None, block=None):
         """Abort specific jobs from the execution queues of target(s).

         This is a mechanism to prevent jobs that have already been submitted
         from executing.

         Parameters
         ----------

         jobs : msg_id, list of msg_ids, or AsyncResult
             The jobs to be aborted


         """
         block = self.block if block is None else block
         targets = self._build_targets(targets)[0]
         msg_ids = []
         if isinstance(jobs, (basestring,AsyncResult)):
             jobs = [jobs]
         bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
         if bad_ids:
             raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
         for j in jobs:
             if isinstance(j, AsyncResult):
                 msg_ids.extend(j.msg_ids)
             else:
                 msg_ids.append(j)
         content = dict(msg_ids=msg_ids)
         for t in targets:
             self.session.send(self._control_socket, 'abort_request',
                     content=content, ident=t)
         error = False
         if block:
             self._flush_ignored_control()
             for i in range(len(targets)):
                 idents,msg = self.session.recv(self._control_socket,0)
                 if self.debug:
                     pprint(msg)
                 if msg['content']['status'] != 'ok':
                     error = self._unwrap_exception(msg['content'])
         else:
             self._ignored_control_replies += len(targets)
         if error:
             raise error

     @spin_first
     def shutdown(self, targets=None, restart=False, hub=False, block=None):
         """Terminates one or more engine processes, optionally including the hub."""
         block = self.block if block is None else block
         if hub:
             targets = 'all'
         targets = self._build_targets(targets)[0]
         for t in targets:
             self.session.send(self._control_socket, 'shutdown_request',
                         content={'restart':restart},ident=t)
         error = False
         if block or hub:
             self._flush_ignored_control()
             for i in range(len(targets)):
                 idents,msg = self.session.recv(self._control_socket, 0)
                 if self.debug:
                     pprint(msg)
                 if msg['content']['status'] != 'ok':
                     error = self._unwrap_exception(msg['content'])
         else:
             self._ignored_control_replies += len(targets)

         if hub:
             time.sleep(0.25)
             self.session.send(self._query_socket, 'shutdown_request')
             idents,msg = self.session.recv(self._query_socket, 0)
             if self.debug:
                 pprint(msg)
             if msg['content']['status'] != 'ok':
                 error = self._unwrap_exception(msg['content'])

         if error:
             raise error

     #--------------------------------------------------------------------------
     # Execution related methods
     #--------------------------------------------------------------------------

     def _maybe_raise(self, result):
         """wrapper for maybe raising an exception if apply failed."""
         if isinstance(result, error.RemoteError):
             raise result

         return result

     def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                             ident=None):
         """construct and send an apply message via a socket.

         This is the principal method with which all engine execution is performed by views.
         """

         assert not self._closed, "cannot use me anymore, I'm closed!"
         # defaults:
         args = args if args is not None else []
         kwargs = kwargs if kwargs is not None else {}
         subheader = subheader if subheader is not None else {}

         # validate arguments
         if not callable(f):
             raise TypeError("f must be callable, not %s"%type(f))
         if not isinstance(args, (tuple, list)):
             raise TypeError("args must be tuple or list, not %s"%type(args))
         if not isinstance(kwargs, dict):
             raise TypeError("kwargs must be dict, not %s"%type(kwargs))
         if not isinstance(subheader, dict):
             raise TypeError("subheader must be dict, not %s"%type(subheader))

         bufs = util.pack_apply_message(f,args,kwargs)

         msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                                 subheader=subheader, track=track)

-        msg_id = msg['msg_id']
+        msg_id = msg['header']['msg_id']
         self.outstanding.add(msg_id)
         if ident:
             # possibly routed to a specific engine
             if isinstance(ident, list):
                 ident = ident[-1]
             if ident in self._engines.values():
                 # save for later, in case of engine death
                 self._outstanding_dict[ident].add(msg_id)
         self.history.append(msg_id)
         self.metadata[msg_id]['submitted'] = datetime.now()

         return msg

     #--------------------------------------------------------------------------
     # construct a View object
     #--------------------------------------------------------------------------

     def load_balanced_view(self, targets=None):
1016 1016 """construct a DirectView object.

         If no arguments are specified, create a LoadBalancedView
         using all engines.

         Parameters
         ----------

         targets: list, slice, int, etc. [default: use all engines]
             The subset of engines across which to load-balance
         """
         if targets == 'all':
             targets = None
         if targets is not None:
             targets = self._build_targets(targets)[1]
         return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)

     def direct_view(self, targets='all'):
         """construct a DirectView object.

         If no targets are specified, create a DirectView
         using all engines.

         Parameters
         ----------

         targets: list, slice, int, etc. [default: use all engines]
             The engines to use for the View
         """
         single = isinstance(targets, int)
         # allow 'all' to be lazily evaluated at each execution
         if targets != 'all':
             targets = self._build_targets(targets)[1]
         if single:
             targets = targets[0]
         return DirectView(client=self, socket=self._mux_socket, targets=targets)

     #--------------------------------------------------------------------------
     # Query methods
     #--------------------------------------------------------------------------

     @spin_first
     def get_result(self, indices_or_msg_ids=None, block=None):
         """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

         If the client already has the results, no request to the Hub will be made.

         This is a convenient way to construct AsyncResult objects, which are wrappers
         that include metadata about execution, and allow for awaiting results that
         were not submitted by this Client.

         It can also be a convenient way to retrieve the metadata associated with
         blocking execution, since it always retrieves the metadata.

         Examples
         --------
         ::

             In [10]: r = client.apply()

         Parameters
         ----------

         indices_or_msg_ids : integer history index, str msg_id, or list of either
             The history indices or msg_ids of the results to be retrieved

         block : bool
             Whether to wait for the result to be done

         Returns
         -------

         AsyncResult
             A single AsyncResult object will always be returned.

         AsyncHubResult
             A subclass of AsyncResult that retrieves results from the Hub

         """
         block = self.block if block is None else block
         if indices_or_msg_ids is None:
             indices_or_msg_ids = -1

         if not isinstance(indices_or_msg_ids, (list,tuple)):
             indices_or_msg_ids = [indices_or_msg_ids]

         theids = []
         for id in indices_or_msg_ids:
             if isinstance(id, int):
                 id = self.history[id]
             if not isinstance(id, basestring):
                 raise TypeError("indices must be str or int, not %r"%id)
             theids.append(id)

         local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
         remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

         if remote_ids:
             ar = AsyncHubResult(self, msg_ids=theids)
         else:
             ar = AsyncResult(self, msg_ids=theids)

         if block:
             ar.wait()

         return ar

     @spin_first
     def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
         """Resubmit one or more tasks.

         in-flight tasks may not be resubmitted.

         Parameters
         ----------

         indices_or_msg_ids : integer history index, str msg_id, or list of either
             The history indices or msg_ids of the results to be retrieved
1134 1134
1135 1135 block : bool
1136 1136 Whether to wait for the result to be done
1137 1137
1138 1138 Returns
1139 1139 -------
1140 1140
1141 1141 AsyncHubResult
1142 1142 A subclass of AsyncResult that retrieves results from the Hub
1143 1143
1144 1144 """
1145 1145 block = self.block if block is None else block
1146 1146 if indices_or_msg_ids is None:
1147 1147 indices_or_msg_ids = -1
1148 1148
1149 1149 if not isinstance(indices_or_msg_ids, (list,tuple)):
1150 1150 indices_or_msg_ids = [indices_or_msg_ids]
1151 1151
1152 1152 theids = []
1153 1153 for id in indices_or_msg_ids:
1154 1154 if isinstance(id, int):
1155 1155 id = self.history[id]
1156 1156 if not isinstance(id, basestring):
1157 1157 raise TypeError("indices must be str or int, not %r"%id)
1158 1158 theids.append(id)
1159 1159
1160 1160 for msg_id in theids:
1161 1161 self.outstanding.discard(msg_id)
1162 1162 if msg_id in self.history:
1163 1163 self.history.remove(msg_id)
1164 1164 self.results.pop(msg_id, None)
1165 1165 self.metadata.pop(msg_id, None)
1166 1166 content = dict(msg_ids = theids)
1167 1167
1168 1168 self.session.send(self._query_socket, 'resubmit_request', content)
1169 1169
1170 1170 zmq.select([self._query_socket], [], [])
1171 1171 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1172 1172 if self.debug:
1173 1173 pprint(msg)
1174 1174 content = msg['content']
1175 1175 if content['status'] != 'ok':
1176 1176 raise self._unwrap_exception(content)
1177 1177
1178 1178 ar = AsyncHubResult(self, msg_ids=theids)
1179 1179
1180 1180 if block:
1181 1181 ar.wait()
1182 1182
1183 1183 return ar
1184 1184
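A sketch of resubmission (same assumptions: `rc` is a connected Client with history; recall that in-flight tasks may not be resubmitted):

    ar = rc.resubmit()    # default: resubmit the most recent task (index -1)
    print ar.get()        # AsyncHubResult: the new result is fetched via the Hub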
1185 1185 @spin_first
1186 1186 def result_status(self, msg_ids, status_only=True):
1187 1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1188 1188
1189 1189 If status_only is False, then the actual results will be retrieved, else
1190 1190 only the status of the results will be checked.
1191 1191
1192 1192 Parameters
1193 1193 ----------
1194 1194
1195 1195 msg_ids : list of msg_ids
1196 1196 if int:
1197 1197 Passed as index to self.history for convenience.
1198 1198 status_only : bool (default: True)
1199 1199 if False:
1200 1200 Retrieve the actual results of completed tasks.
1201 1201
1202 1202 Returns
1203 1203 -------
1204 1204
1205 1205 results : dict
1206 1206 There will always be the keys 'pending' and 'completed', which will
1207 1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1208 1208 is False, then completed results will be keyed by their `msg_id`.
1209 1209 """
1210 1210 if not isinstance(msg_ids, (list,tuple)):
1211 1211 msg_ids = [msg_ids]
1212 1212
1213 1213 theids = []
1214 1214 for msg_id in msg_ids:
1215 1215 if isinstance(msg_id, int):
1216 1216 msg_id = self.history[msg_id]
1217 1217 if not isinstance(msg_id, basestring):
1218 1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1219 1219 theids.append(msg_id)
1220 1220
1221 1221 completed = []
1222 1222 local_results = {}
1223 1223
1224 1224 # comment this block out to temporarily disable local shortcut:
1225 1225 for msg_id in list(theids): # iterate over a copy, since theids is mutated below
1226 1226 if msg_id in self.results:
1227 1227 completed.append(msg_id)
1228 1228 local_results[msg_id] = self.results[msg_id]
1229 1229 theids.remove(msg_id)
1230 1230
1231 1231 if theids: # some not locally cached
1232 1232 content = dict(msg_ids=theids, status_only=status_only)
1233 1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1234 1234 zmq.select([self._query_socket], [], [])
1235 1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1236 1236 if self.debug:
1237 1237 pprint(msg)
1238 1238 content = msg['content']
1239 1239 if content['status'] != 'ok':
1240 1240 raise self._unwrap_exception(content)
1241 1241 buffers = msg['buffers']
1242 1242 else:
1243 1243 content = dict(completed=[],pending=[])
1244 1244
1245 1245 content['completed'].extend(completed)
1246 1246
1247 1247 if status_only:
1248 1248 return content
1249 1249
1250 1250 failures = []
1251 1251 # load cached results into result:
1252 1252 content.update(local_results)
1253 1253
1254 1254 # update cache with results:
1255 1255 for msg_id in sorted(theids):
1256 1256 if msg_id in content['completed']:
1257 1257 rec = content[msg_id]
1258 1258 parent = rec['header']
1259 1259 header = rec['result_header']
1260 1260 rcontent = rec['result_content']
1261 1261 iodict = rec['io']
1262 1262 if isinstance(rcontent, str):
1263 1263 rcontent = self.session.unpack(rcontent)
1264 1264
1265 1265 md = self.metadata[msg_id]
1266 1266 md.update(self._extract_metadata(header, parent, rcontent))
1267 1267 md.update(iodict)
1268 1268
1269 1269 if rcontent['status'] == 'ok':
1270 1270 res,buffers = util.unserialize_object(buffers)
1271 1271 else:
1272 1272 print rcontent
1273 1273 res = self._unwrap_exception(rcontent)
1274 1274 failures.append(res)
1275 1275
1276 1276 self.results[msg_id] = res
1277 1277 content[msg_id] = res
1278 1278
1279 1279 if len(theids) == 1 and failures:
1280 1280 raise failures[0]
1281 1281
1282 1282 error.collect_exceptions(failures, "result_status")
1283 1283 return content
1284 1284
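A sketch of polling status without pulling results (assuming `rc` is a connected Client, as above):

    status = rc.result_status(rc.history, status_only=True)
    print '%i done, %i pending' % (len(status['completed']), len(status['pending']))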
1285 1285 @spin_first
1286 1286 def queue_status(self, targets='all', verbose=False):
1287 1287 """Fetch the status of engine queues.
1288 1288
1289 1289 Parameters
1290 1290 ----------
1291 1291
1292 1292 targets : int/str/list of ints/strs
1293 1293 the engines whose states are to be queried.
1294 1294 default : all
1295 1295 verbose : bool
1296 1296 Whether to return lengths only, or lists of ids for each element
1297 1297 """
1298 1298 engine_ids = self._build_targets(targets)[1]
1299 1299 content = dict(targets=engine_ids, verbose=verbose)
1300 1300 self.session.send(self._query_socket, "queue_request", content=content)
1301 1301 idents,msg = self.session.recv(self._query_socket, 0)
1302 1302 if self.debug:
1303 1303 pprint(msg)
1304 1304 content = msg['content']
1305 1305 status = content.pop('status')
1306 1306 if status != 'ok':
1307 1307 raise self._unwrap_exception(content)
1308 1308 content = rekey(content)
1309 1309 if isinstance(targets, int):
1310 1310 return content[targets]
1311 1311 else:
1312 1312 return content
1313 1313
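A sketch of inspecting queue depths (assuming `rc` as above; the keys mentioned in the comments are illustrative):

    print rc.queue_status()           # dict keyed by engine id, with counts such as
                                      # 'queue', 'completed', and 'tasks'
    print rc.queue_status(targets=0)  # a single int target returns just that engine's dict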
1314 1314 @spin_first
1315 1315 def purge_results(self, jobs=[], targets=[]):
1316 1316 """Tell the Hub to forget results.
1317 1317
1318 1318 Individual results can be purged by msg_id, or the entire
1319 1319 history of specific targets can be purged.
1320 1320
1321 1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1322 1322
1323 1323 Parameters
1324 1324 ----------
1325 1325
1326 1326 jobs : str or list of str or AsyncResult objects
1327 1327 the msg_ids whose results should be forgotten.
1328 1328 targets : int/str/list of ints/strs
1329 1329 The targets, by int_id, whose entire history is to be purged.
1330 1330
1331 1331 default : None
1332 1332 """
1333 1333 if not targets and not jobs:
1334 1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1335 1335 if targets:
1336 1336 targets = self._build_targets(targets)[1]
1337 1337
1338 1338 # construct msg_ids from jobs
1339 1339 if jobs == 'all':
1340 1340 msg_ids = jobs
1341 1341 else:
1342 1342 msg_ids = []
1343 1343 if isinstance(jobs, (basestring,AsyncResult)):
1344 1344 jobs = [jobs]
1345 1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1346 1346 if bad_ids:
1347 1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1348 1348 for j in jobs:
1349 1349 if isinstance(j, AsyncResult):
1350 1350 msg_ids.extend(j.msg_ids)
1351 1351 else:
1352 1352 msg_ids.append(j)
1353 1353
1354 1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1355 1355 self.session.send(self._query_socket, "purge_request", content=content)
1356 1356 idents, msg = self.session.recv(self._query_socket, 0)
1357 1357 if self.debug:
1358 1358 pprint(msg)
1359 1359 content = msg['content']
1360 1360 if content['status'] != 'ok':
1361 1361 raise self._unwrap_exception(content)
1362 1362
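A sketch of forgetting results (assuming `rc` as above and an existing AsyncResult `ar`):

    rc.purge_results(jobs=ar)          # forget the records behind one AsyncResult
    rc.purge_results(targets=[0, 1])   # or drop the entire history of engines 0 and 1
    rc.purge_results('all')            # or scrub everything from the Hub's db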
1363 1363 @spin_first
1364 1364 def hub_history(self):
1365 1365 """Get the Hub's history
1366 1366
1367 1367 Just like the Client, the Hub has a history, which is a list of msg_ids.
1368 1368 This will contain the history of all clients, and, depending on configuration,
1369 1369 may contain history across multiple cluster sessions.
1370 1370
1371 1371 Any msg_id returned here is a valid argument to `get_result`.
1372 1372
1373 1373 Returns
1374 1374 -------
1375 1375
1376 1376 msg_ids : list of strs
1377 1377 list of all msg_ids, ordered by task submission time.
1378 1378 """
1379 1379
1380 1380 self.session.send(self._query_socket, "history_request", content={})
1381 1381 idents, msg = self.session.recv(self._query_socket, 0)
1382 1382
1383 1383 if self.debug:
1384 1384 pprint(msg)
1385 1385 content = msg['content']
1386 1386 if content['status'] != 'ok':
1387 1387 raise self._unwrap_exception(content)
1388 1388 else:
1389 1389 return content['history']
1390 1390
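A sketch tying `hub_history` to `get_result` (assuming `rc` as above):

    all_ids = rc.hub_history()      # msg_ids from all clients, oldest first
    ar = rc.get_result(all_ids[0])  # any of them is a valid get_result argument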
1391 1391 @spin_first
1392 1392 def db_query(self, query, keys=None):
1393 1393 """Query the Hub's TaskRecord database
1394 1394
1395 1395 This will return a list of task record dicts that match `query`
1396 1396
1397 1397 Parameters
1398 1398 ----------
1399 1399
1400 1400 query : mongodb query dict
1401 1401 The search dict. See mongodb query docs for details.
1402 1402 keys : list of strs [optional]
1403 1403 The subset of keys to be returned. The default is to fetch everything but buffers.
1404 1404 'msg_id' will *always* be included.
1405 1405 """
1406 1406 if isinstance(keys, basestring):
1407 1407 keys = [keys]
1408 1408 content = dict(query=query, keys=keys)
1409 1409 self.session.send(self._query_socket, "db_request", content=content)
1410 1410 idents, msg = self.session.recv(self._query_socket, 0)
1411 1411 if self.debug:
1412 1412 pprint(msg)
1413 1413 content = msg['content']
1414 1414 if content['status'] != 'ok':
1415 1415 raise self._unwrap_exception(content)
1416 1416
1417 1417 records = content['records']
1418 1418
1419 1419 buffer_lens = content['buffer_lens']
1420 1420 result_buffer_lens = content['result_buffer_lens']
1421 1421 buffers = msg['buffers']
1422 1422 has_bufs = buffer_lens is not None
1423 1423 has_rbufs = result_buffer_lens is not None
1424 1424 for i,rec in enumerate(records):
1425 1425 # relink buffers
1426 1426 if has_bufs:
1427 1427 blen = buffer_lens[i]
1428 1428 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1429 1429 if has_rbufs:
1430 1430 blen = result_buffer_lens[i]
1431 1431 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1432 1432
1433 1433 return records
1434 1434
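A sketch of a TaskRecord query (assuming `rc` as above; the key names follow the record keys used elsewhere in this code, and the MongoDB-style `$in` operator is assumed to be supported by the DB backend):

    recent = rc.hub_history()[-10:]
    records = rc.db_query({'msg_id': {'$in': recent}},
                          keys=['started', 'completed'])
    for rec in records:
        print rec['msg_id'], rec['started'], rec['completed']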
1435 1435 __all__ = [ 'Client' ]
@@ -1,1048 +1,1048 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
28 28 from IPython.external.decorator import decorator
29 29
30 30 from IPython.parallel import util
31 31 from IPython.parallel.controller.dependency import Dependency, dependent
32 32
33 33 from . import map as Map
34 34 from .asyncresult import AsyncResult, AsyncMapResult
35 35 from .remotefunction import ParallelFunction, parallel, remote
36 36
37 37 #-----------------------------------------------------------------------------
38 38 # Decorators
39 39 #-----------------------------------------------------------------------------
40 40
41 41 @decorator
42 42 def save_ids(f, self, *args, **kwargs):
43 43 """Keep our history and outstanding attributes up to date after a method call."""
44 44 n_previous = len(self.client.history)
45 45 try:
46 46 ret = f(self, *args, **kwargs)
47 47 finally:
48 48 nmsgs = len(self.client.history) - n_previous
49 49 msg_ids = self.client.history[-nmsgs:]
50 50 self.history.extend(msg_ids)
51 51 map(self.outstanding.add, msg_ids)
52 52 return ret
53 53
54 54 @decorator
55 55 def sync_results(f, self, *args, **kwargs):
56 56 """sync relevant results from self.client to our results attribute."""
57 57 ret = f(self, *args, **kwargs)
58 58 delta = self.outstanding.difference(self.client.outstanding)
59 59 completed = self.outstanding.intersection(delta)
60 60 self.outstanding = self.outstanding.difference(completed)
61 61 for msg_id in completed:
62 62 self.results[msg_id] = self.client.results[msg_id]
63 63 return ret
64 64
65 65 @decorator
66 66 def spin_after(f, self, *args, **kwargs):
67 67 """call spin after the method."""
68 68 ret = f(self, *args, **kwargs)
69 69 self.spin()
70 70 return ret
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # Classes
74 74 #-----------------------------------------------------------------------------
75 75
76 76 @skip_doctest
77 77 class View(HasTraits):
78 78 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
79 79
80 80 Don't use this class, use subclasses.
81 81
82 82 Methods
83 83 -------
84 84
85 85 spin
86 86 flushes incoming results and registration state changes
87 87 control methods also spin, and requesting `ids` ensures the state is up to date
88 88
89 89 wait
90 90 wait on one or more msg_ids
91 91
92 92 execution methods
93 93 apply
94 94 legacy: execute, run
95 95
96 96 data movement
97 97 push, pull, scatter, gather
98 98
99 99 query methods
100 100 get_result, queue_status, purge_results, result_status
101 101
102 102 control methods
103 103 abort, shutdown
104 104
105 105 """
106 106 # flags
107 107 block=Bool(False)
108 108 track=Bool(True)
109 109 targets = Any()
110 110
111 111 history=List()
112 112 outstanding = Set()
113 113 results = Dict()
114 114 client = Instance('IPython.parallel.Client')
115 115
116 116 _socket = Instance('zmq.Socket')
117 117 _flag_names = List(['targets', 'block', 'track'])
118 118 _targets = Any()
119 119 _idents = Any()
120 120
121 121 def __init__(self, client=None, socket=None, **flags):
122 122 super(View, self).__init__(client=client, _socket=socket)
123 123 self.block = client.block
124 124
125 125 self.set_flags(**flags)
126 126
127 127 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
128 128
129 129
130 130 def __repr__(self):
131 131 strtargets = str(self.targets)
132 132 if len(strtargets) > 16:
133 133 strtargets = strtargets[:12]+'...]'
134 134 return "<%s %s>"%(self.__class__.__name__, strtargets)
135 135
136 136 def set_flags(self, **kwargs):
137 137 """set my attribute flags by keyword.
138 138
139 139 Views determine behavior with a few attributes (`block`, `track`, etc.).
140 140 These attributes can be set all at once by name with this method.
141 141
142 142 Parameters
143 143 ----------
144 144
145 145 block : bool
146 146 whether to wait for results
147 147 track : bool
148 148 whether to create a MessageTracker to allow the user to
149 149 know when it is safe to edit arrays and buffers again
150 150 after non-copying sends.
151 151 """
152 152 for name, value in kwargs.iteritems():
153 153 if name not in self._flag_names:
154 154 raise KeyError("Invalid name: %r"%name)
155 155 else:
156 156 setattr(self, name, value)
157 157
158 158 @contextmanager
159 159 def temp_flags(self, **kwargs):
160 160 """temporarily set flags, for use in `with` statements.
161 161
162 162 See set_flags for permanent setting of flags
163 163
164 164 Examples
165 165 --------
166 166
167 167 >>> view.track=False
168 168 ...
169 169 >>> with view.temp_flags(track=True):
170 170 ... ar = view.apply(dostuff, my_big_array)
171 171 ... ar.tracker.wait() # wait for send to finish
172 172 >>> view.track
173 173 False
174 174
175 175 """
176 176 # preflight: save flags, and set temporaries
177 177 saved_flags = {}
178 178 for f in self._flag_names:
179 179 saved_flags[f] = getattr(self, f)
180 180 self.set_flags(**kwargs)
181 181 # yield to the with-statement block
182 182 try:
183 183 yield
184 184 finally:
185 185 # postflight: restore saved flags
186 186 self.set_flags(**saved_flags)
187 187
188 188
189 189 #----------------------------------------------------------------
190 190 # apply
191 191 #----------------------------------------------------------------
192 192
193 193 @sync_results
194 194 @save_ids
195 195 def _really_apply(self, f, args, kwargs, block=None, **options):
196 196 """wrapper for client.send_apply_message"""
197 197 raise NotImplementedError("Implement in subclasses")
198 198
199 199 def apply(self, f, *args, **kwargs):
200 200 """calls f(*args, **kwargs) on remote engines, returning the result.
201 201
202 202 This method sets all apply flags via this View's attributes.
203 203
204 204 if self.block is False:
205 205 returns AsyncResult
206 206 else:
207 207 returns actual result of f(*args, **kwargs)
208 208 """
209 209 return self._really_apply(f, args, kwargs)
210 210
211 211 def apply_async(self, f, *args, **kwargs):
212 212 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
213 213
214 214 returns AsyncResult
215 215 """
216 216 return self._really_apply(f, args, kwargs, block=False)
217 217
218 218 @spin_after
219 219 def apply_sync(self, f, *args, **kwargs):
220 220 """calls f(*args, **kwargs) on remote engines in a blocking manner,
221 221 returning the result.
222 222
223 223 returns: actual result of f(*args, **kwargs)
224 224 """
225 225 return self._really_apply(f, args, kwargs, block=True)
226 226
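The three entry points differ only in blocking behavior; a sketch (assuming `view` is a View subclass instance and `f` a function the engines can unpack):

    ar = view.apply(f, 1, 2)           # obeys view.block
    ar = view.apply_async(f, 1, 2)     # always returns an AsyncResult immediately
    result = view.apply_sync(f, 1, 2)  # always blocks and returns f(1, 2)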
227 227 #----------------------------------------------------------------
228 228 # wrappers for client and control methods
229 229 #----------------------------------------------------------------
230 230 @sync_results
231 231 def spin(self):
232 232 """spin the client, and sync"""
233 233 self.client.spin()
234 234
235 235 @sync_results
236 236 def wait(self, jobs=None, timeout=-1):
237 237 """waits on one or more `jobs`, for up to `timeout` seconds.
238 238
239 239 Parameters
240 240 ----------
241 241
242 242 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
243 243 ints are indices to self.history
244 244 strs are msg_ids
245 245 default: wait on all outstanding messages
246 246 timeout : float
247 247 a time in seconds, after which to give up.
248 248 default is -1, which means no timeout
249 249
250 250 Returns
251 251 -------
252 252
253 253 True : when all msg_ids are done
254 254 False : timeout reached, some msg_ids still outstanding
255 255 """
256 256 if jobs is None:
257 257 jobs = self.history
258 258 return self.client.wait(jobs, timeout)
259 259
260 260 def abort(self, jobs=None, targets=None, block=None):
261 261 """Abort jobs on my engines.
262 262
263 263 Parameters
264 264 ----------
265 265
266 266 jobs : None, str, list of strs, optional
267 267 if None: abort all jobs.
268 268 else: abort specific msg_id(s).
269 269 """
270 270 block = block if block is not None else self.block
271 271 targets = targets if targets is not None else self.targets
272 272 return self.client.abort(jobs=jobs, targets=targets, block=block)
273 273
274 274 def queue_status(self, targets=None, verbose=False):
275 275 """Fetch the Queue status of my engines"""
276 276 targets = targets if targets is not None else self.targets
277 277 return self.client.queue_status(targets=targets, verbose=verbose)
278 278
279 279 def purge_results(self, jobs=[], targets=[]):
280 280 """Instruct the controller to forget specific results."""
281 281 if targets is None or targets == 'all':
282 282 targets = self.targets
283 283 return self.client.purge_results(jobs=jobs, targets=targets)
284 284
285 285 def shutdown(self, targets=None, restart=False, hub=False, block=None):
286 286 """Terminates one or more engine processes, optionally including the hub.
287 287 """
288 288 block = self.block if block is None else block
289 289 if targets is None or targets == 'all':
290 290 targets = self.targets
291 291 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
292 292
293 293 @spin_after
294 294 def get_result(self, indices_or_msg_ids=None):
295 295 """return one or more results, specified by history index or msg_id.
296 296
297 297 See client.get_result for details.
298 298
299 299 """
300 300
301 301 if indices_or_msg_ids is None:
302 302 indices_or_msg_ids = -1
303 303 if isinstance(indices_or_msg_ids, int):
304 304 indices_or_msg_ids = self.history[indices_or_msg_ids]
305 305 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
306 306 indices_or_msg_ids = list(indices_or_msg_ids)
307 307 for i,index in enumerate(indices_or_msg_ids):
308 308 if isinstance(index, int):
309 309 indices_or_msg_ids[i] = self.history[index]
310 310 return self.client.get_result(indices_or_msg_ids)
311 311
312 312 #-------------------------------------------------------------------
313 313 # Map
314 314 #-------------------------------------------------------------------
315 315
316 316 def map(self, f, *sequences, **kwargs):
317 317 """override in subclasses"""
318 318 raise NotImplementedError
319 319
320 320 def map_async(self, f, *sequences, **kwargs):
321 321 """Parallel version of builtin `map`, using this view's engines.
322 322
323 323 This is equivalent to map(...block=False)
324 324
325 325 See `self.map` for details.
326 326 """
327 327 if 'block' in kwargs:
328 328 raise TypeError("map_async doesn't take a `block` keyword argument.")
329 329 kwargs['block'] = False
330 330 return self.map(f,*sequences,**kwargs)
331 331
332 332 def map_sync(self, f, *sequences, **kwargs):
333 333 """Parallel version of builtin `map`, using this view's engines.
334 334
335 335 This is equivalent to map(...block=True)
336 336
337 337 See `self.map` for details.
338 338 """
339 339 if 'block' in kwargs:
340 340 raise TypeError("map_sync doesn't take a `block` keyword argument.")
341 341 kwargs['block'] = True
342 342 return self.map(f,*sequences,**kwargs)
343 343
344 344 def imap(self, f, *sequences, **kwargs):
345 345 """Parallel version of `itertools.imap`.
346 346
347 347 See `self.map` for details.
348 348
349 349 """
350 350
351 351 return iter(self.map_async(f,*sequences, **kwargs))
352 352
353 353 #-------------------------------------------------------------------
354 354 # Decorators
355 355 #-------------------------------------------------------------------
356 356
357 357 def remote(self, block=True, **flags):
358 358 """Decorator for making a RemoteFunction"""
359 359 block = self.block if block is None else block
360 360 return remote(self, block=block, **flags)
361 361
362 362 def parallel(self, dist='b', block=None, **flags):
363 363 """Decorator for making a ParallelFunction"""
364 364 block = self.block if block is None else block
365 365 return parallel(self, dist=dist, block=block, **flags)
366 366
367 367 @skip_doctest
368 368 class DirectView(View):
369 369 """Direct Multiplexer View of one or more engines.
370 370
371 371 These are created via indexed access to a client:
372 372
373 373 >>> dv_1 = client[1]
374 374 >>> dv_all = client[:]
375 375 >>> dv_even = client[::2]
376 376 >>> dv_some = client[1:3]
377 377
378 378 This object provides dictionary access to engine namespaces:
379 379
380 380 # push a=5:
381 381 >>> dv['a'] = 5
382 382 # pull 'foo':
383 383 >>> dv['foo']
384 384
385 385 """
386 386
387 387 def __init__(self, client=None, socket=None, targets=None):
388 388 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
389 389
390 390 @property
391 391 def importer(self):
392 392 """sync_imports(local=True) as a property.
393 393
394 394 See sync_imports for details.
395 395
396 396 """
397 397 return self.sync_imports(True)
398 398
399 399 @contextmanager
400 400 def sync_imports(self, local=True):
401 401 """Context Manager for performing simultaneous local and remote imports.
402 402
403 403 'import x as y' will *not* work. The 'as y' part will simply be ignored.
404 404
405 405 >>> with view.sync_imports():
406 406 ... from numpy import recarray
407 407 importing recarray from numpy on engine(s)
408 408
409 409 """
410 410 import __builtin__
411 411 local_import = __builtin__.__import__
412 412 modules = set()
413 413 results = []
414 414 @util.interactive
415 415 def remote_import(name, fromlist, level):
416 416 """the function to be passed to apply, that actually performs the import
417 417 on the engine, and loads up the user namespace.
418 418 """
419 419 import sys
420 420 user_ns = globals()
421 421 mod = __import__(name, fromlist=fromlist, level=level)
422 422 if fromlist:
423 423 for key in fromlist:
424 424 user_ns[key] = getattr(mod, key)
425 425 else:
426 426 user_ns[name] = sys.modules[name]
427 427
428 428 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
429 429 """the drop-in replacement for __import__, that optionally imports
430 430 locally as well.
431 431 """
432 432 # don't override nested imports
433 433 save_import = __builtin__.__import__
434 434 __builtin__.__import__ = local_import
435 435
436 436 if imp.lock_held():
437 437 # this is a side-effect import, don't do it remotely, or even
438 438 # ignore the local effects
439 439 return local_import(name, globals, locals, fromlist, level)
440 440
441 441 imp.acquire_lock()
442 442 if local:
443 443 mod = local_import(name, globals, locals, fromlist, level)
444 444 else:
445 445 raise NotImplementedError("remote-only imports not yet implemented")
446 446 imp.release_lock()
447 447
448 448 key = name+':'+','.join(fromlist or [])
449 449 if level == -1 and key not in modules:
450 450 modules.add(key)
451 451 if fromlist:
452 452 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
453 453 else:
454 454 print "importing %s on engine(s)"%name
455 455 results.append(self.apply_async(remote_import, name, fromlist, level))
456 456 # restore override
457 457 __builtin__.__import__ = save_import
458 458
459 459 return mod
460 460
461 461 # override __import__
462 462 __builtin__.__import__ = view_import
463 463 try:
464 464 # enter the block
465 465 yield
466 466 except ImportError:
467 467 if not local:
468 468 # ignore import errors if not doing local imports
469 469 pass
470 470 finally:
471 471 # always restore __import__
472 472 __builtin__.__import__ = local_import
473 473
474 474 for r in results:
475 475 # raise possible remote ImportErrors here
476 476 r.get()
477 477
478 478
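A sketch of the context manager in use (assuming `dv` is a DirectView):

    with dv.sync_imports():
        import numpy              # imported locally *and* on every engine
    # per the docstring above, 'import numpy as np' would ignore the 'as np' part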
479 479 @sync_results
480 480 @save_ids
481 481 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
482 482 """calls f(*args, **kwargs) on remote engines, returning the result.
483 483
484 484 This method sets all of `apply`'s flags via this View's attributes.
485 485
486 486 Parameters
487 487 ----------
488 488
489 489 f : callable
490 490
491 491 args : list [default: empty]
492 492
493 493 kwargs : dict [default: empty]
494 494
495 495 targets : target list [default: self.targets]
496 496 where to run
497 497 block : bool [default: self.block]
498 498 whether to block
499 499 track : bool [default: self.track]
500 500 whether to ask zmq to track the message, for safe non-copying sends
501 501
502 502 Returns
503 503 -------
504 504
505 505 if self.block is False:
506 506 returns AsyncResult
507 507 else:
508 508 returns actual result of f(*args, **kwargs) on the engine(s)
509 509 This will be a list if self.targets is also a list (even length 1), or
510 510 the single result if self.targets is an integer engine id
511 511 """
512 512 args = [] if args is None else args
513 513 kwargs = {} if kwargs is None else kwargs
514 514 block = self.block if block is None else block
515 515 track = self.track if track is None else track
516 516 targets = self.targets if targets is None else targets
517 517
518 518 _idents = self.client._build_targets(targets)[0]
519 519 msg_ids = []
520 520 trackers = []
521 521 for ident in _idents:
522 522 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
523 523 ident=ident)
524 524 if track:
525 525 trackers.append(msg['tracker'])
526 msg_ids.append(msg['msg_id'])
526 msg_ids.append(msg['header']['msg_id'])
527 527 tracker = None if track is False else zmq.MessageTracker(*trackers)
528 528 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
529 529 if block:
530 530 try:
531 531 return ar.get()
532 532 except KeyboardInterrupt:
533 533 pass
534 534 return ar
535 535
536 536 @spin_after
537 537 def map(self, f, *sequences, **kwargs):
538 538 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
539 539
540 540 Parallel version of builtin `map`, using this View's `targets`.
541 541
542 542 There will be one task per target, so work will be chunked
543 543 if the sequences are longer than `targets`.
544 544
545 545 Results can be iterated as they are ready, but will become available in chunks.
546 546
547 547 Parameters
548 548 ----------
549 549
550 550 f : callable
551 551 function to be mapped
552 552 *sequences: one or more sequences of matching length
553 553 the sequences to be distributed and passed to `f`
554 554 block : bool
555 555 whether to wait for the result or not [default self.block]
556 556
557 557 Returns
558 558 -------
559 559
560 560 if block=False:
561 561 AsyncMapResult
562 562 An object like AsyncResult, but which reassembles the sequence of results
563 563 into a single list. AsyncMapResults can be iterated through before all
564 564 results are complete.
565 565 else:
566 566 list
567 567 the result of map(f,*sequences)
568 568 """
569 569
570 570 block = kwargs.pop('block', self.block)
571 571 for k in kwargs.keys():
572 572 if k not in ['block', 'track']:
573 573 raise TypeError("invalid keyword arg, %r"%k)
574 574
575 575 assert len(sequences) > 0, "must have some sequences to map onto!"
576 576 pf = ParallelFunction(self, f, block=block, **kwargs)
577 577 return pf.map(*sequences)
578 578
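A sketch of a direct-view map (assuming `dv` is a DirectView; one chunk per engine, so results arrive engine-sized chunks at a time):

    def square(x):
        return x * x

    amr = dv.map_async(square, range(32))
    print amr.get()               # reassembled into a single list of 32 results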
579 579 def execute(self, code, targets=None, block=None):
580 580 """Executes `code` on `targets` in blocking or nonblocking manner.
581 581
582 582 ``execute`` is always `bound` (affects engine namespace)
583 583
584 584 Parameters
585 585 ----------
586 586
587 587 code : str
588 588 the code string to be executed
589 589 block : bool
590 590 whether or not to wait until done to return
591 591 default: self.block
592 592 """
593 593 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
594 594
595 595 def run(self, filename, targets=None, block=None):
596 596 """Execute contents of `filename` on my engine(s).
597 597
598 598 This simply reads the contents of the file and calls `execute`.
599 599
600 600 Parameters
601 601 ----------
602 602
603 603 filename : str
604 604 The path to the file
605 605 targets : int/str/list of ints/strs
606 606 the engines on which to execute
607 607 default : all
608 608 block : bool
609 609 whether or not to wait until done
610 610 default: self.block
611 611
612 612 """
613 613 with open(filename, 'r') as f:
614 614 # add newline in case of trailing indented whitespace
615 615 # which will cause SyntaxError
616 616 code = f.read()+'\n'
617 617 return self.execute(code, block=block, targets=targets)
618 618
619 619 def update(self, ns):
620 620 """update remote namespace with dict `ns`
621 621
622 622 See `push` for details.
623 623 """
624 624 return self.push(ns, block=self.block, track=self.track)
625 625
626 626 def push(self, ns, targets=None, block=None, track=None):
627 627 """update remote namespace with dict `ns`
628 628
629 629 Parameters
630 630 ----------
631 631
632 632 ns : dict
633 633 dict of keys with which to update engine namespace(s)
634 634 block : bool [default : self.block]
635 635 whether to wait to be notified of engine receipt
636 636
637 637 """
638 638
639 639 block = block if block is not None else self.block
640 640 track = track if track is not None else self.track
641 641 targets = targets if targets is not None else self.targets
642 642 # applier = self.apply_sync if block else self.apply_async
643 643 if not isinstance(ns, dict):
644 644 raise TypeError("Must be a dict, not %s"%type(ns))
645 645 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
646 646
647 647 def get(self, key_s):
648 648 """get object(s) by `key_s` from remote namespace
649 649
650 650 see `pull` for details.
651 651 """
652 652 # block = block if block is not None else self.block
653 653 return self.pull(key_s, block=True)
654 654
655 655 def pull(self, names, targets=None, block=None):
656 656 """get object(s) by `name` from remote namespace
657 657
658 658 Will return one object if `names` is a single key.
659 659 It can also take a list of keys, in which case it returns a list of objects.
660 660 """
661 661 block = block if block is not None else self.block
662 662 targets = targets if targets is not None else self.targets
663 663 applier = self.apply_sync if block else self.apply_async
664 664 if isinstance(names, basestring):
665 665 pass
666 666 elif isinstance(names, (list,tuple,set)):
667 667 for key in names:
668 668 if not isinstance(key, basestring):
669 669 raise TypeError("keys must be str, not type %r"%type(key))
670 670 else:
671 671 raise TypeError("names must be strs, not %r"%names)
672 672 return self._really_apply(util._pull, (names,), block=block, targets=targets)
673 673
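A sketch of namespace movement (assuming `dv` is a DirectView; with multiple targets, `pull` returns one value per engine):

    dv.push(dict(a=1, b=2))       # update each engine's namespace
    print dv.pull('a')            # a single key
    print dv.pull(('a', 'b'))     # a list of keys -> a list of objects
    dv['c'] = 3                   # dictionary access wraps push/pull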
674 674 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
675 675 """
676 676 Partition a Python sequence and send the partitions to a set of engines.
677 677 """
678 678 block = block if block is not None else self.block
679 679 track = track if track is not None else self.track
680 680 targets = targets if targets is not None else self.targets
681 681
682 682 mapObject = Map.dists[dist]()
683 683 nparts = len(targets)
684 684 msg_ids = []
685 685 trackers = []
686 686 for index, engineid in enumerate(targets):
687 687 partition = mapObject.getPartition(seq, index, nparts)
688 688 if flatten and len(partition) == 1:
689 689 ns = {key: partition[0]}
690 690 else:
691 691 ns = {key: partition}
692 692 r = self.push(ns, block=False, track=track, targets=engineid)
693 693 msg_ids.extend(r.msg_ids)
694 694 if track:
695 695 trackers.append(r._tracker)
696 696
697 697 if track:
698 698 tracker = zmq.MessageTracker(*trackers)
699 699 else:
700 700 tracker = None
701 701
702 702 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
703 703 if block:
704 704 r.wait()
705 705 else:
706 706 return r
707 707
708 708 @sync_results
709 709 @save_ids
710 710 def gather(self, key, dist='b', targets=None, block=None):
711 711 """
712 712 Gather a partitioned sequence on a set of engines as a single local seq.
713 713 """
714 714 block = block if block is not None else self.block
715 715 targets = targets if targets is not None else self.targets
716 716 mapObject = Map.dists[dist]()
717 717 msg_ids = []
718 718
719 719 for index, engineid in enumerate(targets):
720 720 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
721 721
722 722 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
723 723
724 724 if block:
725 725 try:
726 726 return r.get()
727 727 except KeyboardInterrupt:
728 728 pass
729 729 return r
730 730
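A scatter/gather round trip (assuming `dv` is a DirectView):

    dv.scatter('a', range(16))              # partition range(16) across the engines
    dv.execute('b = [x * 2 for x in a]')    # work on each partition remotely
    doubled = dv.gather('b', block=True)    # reassemble into one local list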
731 731 def __getitem__(self, key):
732 732 return self.get(key)
733 733
734 734 def __setitem__(self,key, value):
735 735 self.update({key:value})
736 736
737 737 def clear(self, targets=None, block=False):
738 738 """Clear the remote namespaces on my engines."""
739 739 block = block if block is not None else self.block
740 740 targets = targets if targets is not None else self.targets
741 741 return self.client.clear(targets=targets, block=block)
742 742
743 743 def kill(self, targets=None, block=True):
744 744 """Kill my engines."""
745 745 block = block if block is not None else self.block
746 746 targets = targets if targets is not None else self.targets
747 747 return self.client.kill(targets=targets, block=block)
748 748
749 749 #----------------------------------------
750 750 # activate for %px,%autopx magics
751 751 #----------------------------------------
752 752 def activate(self):
753 753 """Make this `View` active for parallel magic commands.
754 754
755 755 IPython has a magic command syntax to work with `MultiEngineClient` objects.
756 756 In a given IPython session there is a single active one. While
757 757 there can be many `Views` created and used by the user,
758 758 there is only one active one. The active `View` is used whenever
759 759 the magic commands %px and %autopx are used.
760 760
761 761 The activate() method is called on a given `View` to make it
762 762 active. Once this has been done, the magic commands can be used.
763 763 """
764 764
765 765 try:
766 766 # This is injected into __builtins__.
767 767 ip = get_ipython()
768 768 except NameError:
769 769 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
770 770 else:
771 771 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
772 772 if pmagic is None:
773 773 ip.magic_load_ext('parallelmagic')
774 774 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
775 775
776 776 pmagic.active_view = self
777 777
778 778
779 779 @skip_doctest
780 780 class LoadBalancedView(View):
781 781 """An load-balancing View that only executes via the Task scheduler.
782 782
783 783 Load-balanced views can be created with the client's `view` method:
784 784
785 785 >>> v = client.load_balanced_view()
786 786
787 787 or targets can be specified, to restrict the potential destinations:
788 788
789 789 >>> v = client.load_balanced_view([1, 3])
790 790
791 791 which would restrict load-balancing to engines 1 and 3.
792 792
793 793 """
794 794
795 795 follow=Any()
796 796 after=Any()
797 797 timeout=CFloat()
798 798 retries = CInt(0)
799 799
800 800 _task_scheme = Any()
801 801 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
802 802
803 803 def __init__(self, client=None, socket=None, **flags):
804 804 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
805 805 self._task_scheme=client._task_scheme
806 806
807 807 def _validate_dependency(self, dep):
808 808 """validate a dependency.
809 809
810 810 For use in `set_flags`.
811 811 """
812 812 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
813 813 return True
814 814 elif isinstance(dep, (list,set, tuple)):
815 815 for d in dep:
816 816 if not isinstance(d, (basestring, AsyncResult)):
817 817 return False
818 818 elif isinstance(dep, dict):
819 819 if set(dep.keys()) != set(Dependency().as_dict().keys()):
820 820 return False
821 821 if not isinstance(dep['msg_ids'], list):
822 822 return False
823 823 for d in dep['msg_ids']:
824 824 if not isinstance(d, basestring):
825 825 return False
826 826 else:
827 827 return False
828 828
829 829 return True
830 830
831 831 def _render_dependency(self, dep):
832 832 """helper for building jsonable dependencies from various input forms."""
833 833 if isinstance(dep, Dependency):
834 834 return dep.as_dict()
835 835 elif isinstance(dep, AsyncResult):
836 836 return dep.msg_ids
837 837 elif dep is None:
838 838 return []
839 839 else:
840 840 # pass to Dependency constructor
841 841 return list(Dependency(dep))
842 842
843 843 def set_flags(self, **kwargs):
844 844 """set my attribute flags by keyword.
845 845
846 846 A View is a wrapper for the Client's apply method, but with attributes
847 847 that specify keyword arguments, those attributes can be set by keyword
848 848 argument with this method.
849 849
850 850 Parameters
851 851 ----------
852 852
853 853 block : bool
854 854 whether to wait for results
855 855 track : bool
856 856 whether to create a MessageTracker to allow the user to
857 857 know when it is safe to edit arrays and buffers again
858 858 after non-copying sends.
859 859
860 860 after : Dependency or collection of msg_ids
861 861 Only for load-balanced execution (targets=None)
862 862 Specify a list of msg_ids as a time-based dependency.
863 863 This job will only be run *after* the dependencies
864 864 have been met.
865 865
866 866 follow : Dependency or collection of msg_ids
867 867 Only for load-balanced execution (targets=None)
868 868 Specify a list of msg_ids as a location-based dependency.
869 869 This job will only be run on an engine where this dependency
870 870 is met.
871 871
872 872 timeout : float/int or None
873 873 Only for load-balanced execution (targets=None)
874 874 Specify an amount of time (in seconds) for the scheduler to
875 875 wait for dependencies to be met before failing with a
876 876 DependencyTimeout.
877 877
878 878 retries : int
879 879 Number of times a task will be retried on failure.
880 880 """
881 881
882 882 super(LoadBalancedView, self).set_flags(**kwargs)
883 883 for name in ('follow', 'after'):
884 884 if name in kwargs:
885 885 value = kwargs[name]
886 886 if self._validate_dependency(value):
887 887 setattr(self, name, value)
888 888 else:
889 889 raise ValueError("Invalid dependency: %r"%value)
890 890 if 'timeout' in kwargs:
891 891 t = kwargs['timeout']
892 892 if not isinstance(t, (int, long, float, type(None))):
893 893 raise TypeError("Invalid type for timeout: %r"%type(t))
894 894 if t is not None:
895 895 if t < 0:
896 896 raise ValueError("Invalid timeout: %s"%t)
897 897 self.timeout = t
898 898
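A sketch of dependency flags on a load-balanced view (assuming `rc` is a connected Client, `ar` an existing AsyncResult, and `f` a function the engines can unpack):

    lv = rc.load_balanced_view()
    with lv.temp_flags(after=ar, retries=2, timeout=30):
        dep_ar = lv.apply_async(f)   # scheduled only after ar's msg_ids complete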
899 899 @sync_results
900 900 @save_ids
901 901 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
902 902 after=None, follow=None, timeout=None,
903 903 targets=None, retries=None):
904 904 """calls f(*args, **kwargs) on a remote engine, returning the result.
905 905
906 906 This method temporarily sets all of `apply`'s flags for a single call.
907 907
908 908 Parameters
909 909 ----------
910 910
911 911 f : callable
912 912
913 913 args : list [default: empty]
914 914
915 915 kwargs : dict [default: empty]
916 916
917 917 block : bool [default: self.block]
918 918 whether to block
919 919 track : bool [default: self.track]
920 920 whether to ask zmq to track the message, for safe non-copying sends
921 921
922 922 after, follow, timeout, targets, retries : optional, see `set_flags`
923 923
924 924 Returns
925 925 -------
926 926
927 927 if self.block is False:
928 928 returns AsyncResult
929 929 else:
930 930 returns actual result of f(*args, **kwargs) on the engine(s)
931 931 This will be a list if self.targets is also a list (even length 1), or
932 932 the single result if self.targets is an integer engine id
933 933 """
934 934
935 935 # validate whether we can run
936 936 if self._socket.closed:
937 937 msg = "Task farming is disabled"
938 938 if self._task_scheme == 'pure':
939 939 msg += " because the pure ZMQ scheduler cannot handle"
940 940 msg += " disappearing engines."
941 941 raise RuntimeError(msg)
942 942
943 943 if self._task_scheme == 'pure':
944 944 # pure zmq scheme doesn't support extra features
945 945 msg = "Pure ZMQ scheduler doesn't support the following flags:"
946 946 "follow, after, retries, targets, timeout"
947 947 if (follow or after or retries or targets or timeout):
948 948 # hard fail on Scheduler flags
949 949 raise RuntimeError(msg)
950 950 if isinstance(f, dependent):
951 951 # soft warn on functional dependencies
952 952 warnings.warn(msg, RuntimeWarning)
953 953
954 954 # build args
955 955 args = [] if args is None else args
956 956 kwargs = {} if kwargs is None else kwargs
957 957 block = self.block if block is None else block
958 958 track = self.track if track is None else track
959 959 after = self.after if after is None else after
960 960 retries = self.retries if retries is None else retries
961 961 follow = self.follow if follow is None else follow
962 962 timeout = self.timeout if timeout is None else timeout
963 963 targets = self.targets if targets is None else targets
964 964
965 965 if not isinstance(retries, int):
966 966 raise TypeError('retries must be int, not %r'%type(retries))
967 967
968 968 if targets is None:
969 969 idents = []
970 970 else:
971 971 idents = self.client._build_targets(targets)[0]
972 972 # ensure *not* bytes
973 973 idents = [ ident.decode() for ident in idents ]
974 974
975 975 after = self._render_dependency(after)
976 976 follow = self._render_dependency(follow)
977 977 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
978 978
979 979 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
980 980 subheader=subheader)
981 981 tracker = None if track is False else msg['tracker']
982 982
983 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
983 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
984 984
985 985 if block:
986 986 try:
987 987 return ar.get()
988 988 except KeyboardInterrupt:
989 989 pass
990 990 return ar
991 991
992 992 @spin_after
993 993 @save_ids
994 994 def map(self, f, *sequences, **kwargs):
995 995 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
996 996
997 997 Parallel version of builtin `map`, load-balanced by this View.
998 998
999 999 `block`, and `chunksize` can be specified by keyword only.
1000 1000
1001 1001 Each `chunksize` elements will be a separate task, and will be
1002 1002 load-balanced. This lets individual elements be available for iteration
1003 1003 as soon as they arrive.
1004 1004
1005 1005 Parameters
1006 1006 ----------
1007 1007
1008 1008 f : callable
1009 1009 function to be mapped
1010 1010 *sequences: one or more sequences of matching length
1011 1011 the sequences to be distributed and passed to `f`
1012 1012 block : bool
1013 1013 whether to wait for the result or not [default self.block]
1014 1014 track : bool
1015 1015 whether to create a MessageTracker to allow the user to
1016 1016 know when it is safe to edit arrays and buffers again
1017 1017 after non-copying sends.
1018 1018 chunksize : int
1019 1019 how many elements should be in each task [default 1]
1020 1020
1021 1021 Returns
1022 1022 -------
1023 1023
1024 1024 if block=False:
1025 1025 AsyncMapResult
1026 1026 An object like AsyncResult, but which reassembles the sequence of results
1027 1027 into a single list. AsyncMapResults can be iterated through before all
1028 1028 results are complete.
1029 1029 else:
1030 1030 the result of map(f,*sequences)
1031 1031
1032 1032 """
1033 1033
1034 1034 # default
1035 1035 block = kwargs.get('block', self.block)
1036 1036 chunksize = kwargs.get('chunksize', 1)
1037 1037
1038 1038 keyset = set(kwargs.keys())
1039 1039 extra_keys = keyset.difference(set(['block', 'chunksize'])) # difference(), not difference_update(), which returns None
1040 1040 if extra_keys:
1041 1041 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1042 1042
1043 1043 assert len(sequences) > 0, "must have some sequences to map onto!"
1044 1044
1045 1045 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1046 1046 return pf.map(*sequences)
1047 1047
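A sketch of chunked load-balanced mapping (assuming `lv` is a LoadBalancedView and `process` a function the engines can unpack):

    amr = lv.map(process, range(100), block=False, chunksize=4)  # 25 tasks of 4
    for item in amr:              # items become available as chunks complete
        print item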
1048 1048 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,1291 +1,1291 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20 from __future__ import print_function
21 21
22 22 import sys
23 23 import time
24 24 from datetime import datetime
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop
28 28 from zmq.eventloop.zmqstream import ZMQStream
29 29
30 30 # internal:
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
54 54 def empty_record():
55 55 """Return an empty dict with all record keys."""
56 56 return {
57 57 'msg_id' : None,
58 58 'header' : None,
59 59 'content': None,
60 60 'buffers': None,
61 61 'submitted': None,
62 62 'client_uuid' : None,
63 63 'engine_uuid' : None,
64 64 'started': None,
65 65 'completed': None,
66 66 'resubmitted': None,
67 67 'result_header' : None,
68 68 'result_content' : None,
69 69 'result_buffers' : None,
70 70 'queue' : None,
71 71 'pyin' : None,
72 72 'pyout': None,
73 73 'pyerr': None,
74 74 'stdout': '',
75 75 'stderr': '',
76 76 }
77 77
78 78 def init_record(msg):
79 79 """Initialize a TaskRecord based on a request."""
80 80 header = msg['header']
81 81 return {
82 82 'msg_id' : header['msg_id'],
83 83 'header' : header,
84 84 'content': msg['content'],
85 85 'buffers': msg['buffers'],
86 86 'submitted': header['date'],
87 87 'client_uuid' : None,
88 88 'engine_uuid' : None,
89 89 'started': None,
90 90 'completed': None,
91 91 'resubmitted': None,
92 92 'result_header' : None,
93 93 'result_content' : None,
94 94 'result_buffers' : None,
95 95 'queue' : None,
96 96 'pyin' : None,
97 97 'pyout': None,
98 98 'pyerr': None,
99 99 'stdout': '',
100 100 'stderr': '',
101 101 }
102 102
103 103
104 104 class EngineConnector(HasTraits):
105 105 """A simple object for accessing the various zmq connections of an object.
106 106 Attributes are:
107 107 id (int): engine ID
108 108 uuid (str): uuid (unused?)
109 109 queue (str): identity of queue's XREQ socket
110 110 registration (str): identity of registration XREQ socket
111 111 heartbeat (str): identity of heartbeat XREQ socket
112 112 """
113 113 id=Int(0)
114 114 queue=CBytes()
115 115 control=CBytes()
116 116 registration=CBytes()
117 117 heartbeat=CBytes()
118 118 pending=Set()
119 119
120 120 class HubFactory(RegistrationFactory):
121 121 """The Configurable for setting up a Hub."""
122 122
123 123 # port-pairs for monitoredqueues:
124 124 hb = Tuple(Int,Int,config=True,
125 125 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 126 def _hb_default(self):
127 127 return tuple(util.select_random_ports(2))
128 128
129 129 mux = Tuple(Int,Int,config=True,
130 130 help="""Engine/Client Port pair for MUX queue""")
131 131
132 132 def _mux_default(self):
133 133 return tuple(util.select_random_ports(2))
134 134
135 135 task = Tuple(Int,Int,config=True,
136 136 help="""Engine/Client Port pair for Task queue""")
137 137 def _task_default(self):
138 138 return tuple(util.select_random_ports(2))
139 139
140 140 control = Tuple(Int,Int,config=True,
141 141 help="""Engine/Client Port pair for Control queue""")
142 142
143 143 def _control_default(self):
144 144 return tuple(util.select_random_ports(2))
145 145
146 146 iopub = Tuple(Int,Int,config=True,
147 147 help="""Engine/Client Port pair for IOPub relay""")
148 148
149 149 def _iopub_default(self):
150 150 return tuple(util.select_random_ports(2))
151 151
152 152 # single ports:
153 153 mon_port = Int(config=True,
154 154 help="""Monitor (SUB) port for queue traffic""")
155 155
156 156 def _mon_port_default(self):
157 157 return util.select_random_ports(1)[0]
158 158
159 159 notifier_port = Int(config=True,
160 160 help="""PUB port for sending engine status notifications""")
161 161
162 162 def _notifier_port_default(self):
163 163 return util.select_random_ports(1)[0]
164 164
165 165 engine_ip = Unicode('127.0.0.1', config=True,
166 166 help="IP on which to listen for engine connections. [default: loopback]")
167 167 engine_transport = Unicode('tcp', config=True,
168 168 help="0MQ transport for engine connections. [default: tcp]")
169 169
170 170 client_ip = Unicode('127.0.0.1', config=True,
171 171 help="IP on which to listen for client connections. [default: loopback]")
172 172 client_transport = Unicode('tcp', config=True,
173 173 help="0MQ transport for client connections. [default : tcp]")
174 174
175 175 monitor_ip = Unicode('127.0.0.1', config=True,
176 176 help="IP on which to listen for monitor messages. [default: loopback]")
177 177 monitor_transport = Unicode('tcp', config=True,
178 178 help="0MQ transport for monitor messages. [default : tcp]")
179 179
180 180 monitor_url = Unicode('')
181 181
182 182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 183 config=True, help="""The class to use for the DB backend""")
184 184
185 185 # not configurable
186 186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188 188
189 189 def _ip_changed(self, name, old, new):
190 190 self.engine_ip = new
191 191 self.client_ip = new
192 192 self.monitor_ip = new
193 193 self._update_monitor_url()
194 194
195 195 def _update_monitor_url(self):
196 196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197 197
198 198 def _transport_changed(self, name, old, new):
199 199 self.engine_transport = new
200 200 self.client_transport = new
201 201 self.monitor_transport = new
202 202 self._update_monitor_url()
203 203
204 204 def __init__(self, **kwargs):
205 205 super(HubFactory, self).__init__(**kwargs)
206 206 self._update_monitor_url()
207 207
208 208
209 209 def construct(self):
210 210 self.init_hub()
211 211
212 212 def start(self):
213 213 self.heartmonitor.start()
214 214 self.log.info("Heartmonitor started")
215 215
216 216 def init_hub(self):
217 217 """construct"""
218 218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220 220
221 221 ctx = self.context
222 222 loop = self.loop
223 223
224 224 # Registrar socket
225 225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 226 q.bind(client_iface % self.regport)
227 227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 228 if self.client_ip != self.engine_ip:
229 229 q.bind(engine_iface % self.regport)
230 230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
231 231
232 232 ### Engine connections ###
233 233
234 234 # heartbeat
235 235 hpub = ctx.socket(zmq.PUB)
236 236 hpub.bind(engine_iface % self.hb[0])
237 237 hrep = ctx.socket(zmq.XREP)
238 238 hrep.bind(engine_iface % self.hb[1])
239 239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
240 240 pingstream=ZMQStream(hpub,loop),
241 241 pongstream=ZMQStream(hrep,loop)
242 242 )
243 243
244 244 ### Client connections ###
245 245 # Notifier socket
246 246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
247 247 n.bind(client_iface%self.notifier_port)
248 248
249 249 ### build and launch the queues ###
250 250
251 251 # monitor socket
252 252 sub = ctx.socket(zmq.SUB)
253 253 sub.setsockopt(zmq.SUBSCRIBE, b"")
254 254 sub.bind(self.monitor_url)
255 255 sub.bind('inproc://monitor')
256 256 sub = ZMQStream(sub, loop)
257 257
258 258 # connect the db
259 259 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
260 260 # cdir = self.config.Global.cluster_dir
261 261 self.db = import_item(str(self.db_class))(session=self.session.session,
262 262 config=self.config, log=self.log)
263 263 time.sleep(.25)
264 264 try:
265 265 scheme = self.config.TaskScheduler.scheme_name
266 266 except AttributeError:
267 267 from .scheduler import TaskScheduler
268 268 scheme = TaskScheduler.scheme_name.get_default_value()
269 269 # build connection dicts
270 270 self.engine_info = {
271 271 'control' : engine_iface%self.control[1],
272 272 'mux': engine_iface%self.mux[1],
273 273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
274 274 'task' : engine_iface%self.task[1],
275 275 'iopub' : engine_iface%self.iopub[1],
276 276 # 'monitor' : engine_iface%self.mon_port,
277 277 }
278 278
279 279 self.client_info = {
280 280 'control' : client_iface%self.control[0],
281 281 'mux': client_iface%self.mux[0],
282 282 'task' : (scheme, client_iface%self.task[0]),
283 283 'iopub' : client_iface%self.iopub[0],
284 284 'notification': client_iface%self.notifier_port
285 285 }
286 286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
287 287 self.log.debug("Hub client addrs: %s"%self.client_info)
288 288
289 289 # resubmit stream
290 290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
291 291 url = util.disambiguate_url(self.client_info['task'][-1])
292 292 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
293 293 r.connect(url)
294 294
295 295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
296 296 query=q, notifier=n, resubmit=r, db=self.db,
297 297 engine_info=self.engine_info, client_info=self.client_info,
298 298 log=self.log)
299 299
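# A minimal usage sketch (editorial, values hypothetical): a HubFactory is
# configured via traitlets, builds its sockets with construct(), and then
# starts heartbeating:
# >>> factory = HubFactory(ip='127.0.0.1')
# >>> factory.construct() # binds registration/monitor/notifier sockets
# >>> factory.start() # starts the HeartMonitor
# >>> hub = factory.hub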
300 300
301 301 class Hub(SessionFactory):
302 302 """The IPython Controller Hub with 0MQ connections
303 303
304 304 Parameters
305 305 ==========
306 306 loop: zmq IOLoop instance
307 307 session: Session object
308 308 <removed> context: zmq context for creating new connections (?)
309 309 queue: ZMQStream for monitoring the command queue (SUB)
310 310 query: ZMQStream for engine registration and client query requests (XREP)
311 311 heartbeat: HeartMonitor object checking the pulse of the engines
312 312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 313 db: connection to db for out-of-memory logging of commands
314 314 (NotImplemented)
315 315 engine_info: dict of zmq connection information for engines to connect
316 316 to the queues.
317 317 client_info: dict of zmq connection information for clients to connect
318 318 to the queues.
319 319 """
320 320 # internal data structures:
321 321 ids=Set() # engine IDs
322 322 keytable=Dict()
323 323 by_ident=Dict()
324 324 engines=Dict()
325 325 clients=Dict()
326 326 hearts=Dict()
327 327 pending=Set()
328 328 queues=Dict() # pending msg_ids keyed by engine_id
329 329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 330 completed=Dict() # completed msg_ids keyed by engine_id
331 331 all_completed=Set() # set of all completed msg_ids
332 332 dead_engines=Set() # set of queue uuids of engines that have died
333 333 unassigned=Set() # set of task msg_ids not yet assigned a destination
334 334 incoming_registrations=Dict()
335 335 registration_timeout=Int()
336 336 _idcounter=Int(0)
337 337
338 338 # objects from constructor:
339 339 query=Instance(ZMQStream)
340 340 monitor=Instance(ZMQStream)
341 341 notifier=Instance(ZMQStream)
342 342 resubmit=Instance(ZMQStream)
343 343 heartmonitor=Instance(HeartMonitor)
344 344 db=Instance(object)
345 345 client_info=Dict()
346 346 engine_info=Dict()
347 347
348 348
349 349 def __init__(self, **kwargs):
350 350 """
351 351 # universal:
352 352 loop: IOLoop for creating future connections
353 353 session: streamsession for sending serialized data
354 354 # engine:
355 355 queue: ZMQStream for monitoring queue messages
356 356 query: ZMQStream for engine+client registration and client requests
357 357 heartbeat: HeartMonitor object for tracking engines
358 358 # extra:
359 359 db: ZMQStream for db connection (NotImplemented)
360 360 engine_info: zmq address/protocol dict for engine connections
361 361 client_info: zmq address/protocol dict for client connections
362 362 """
363 363
364 364 super(Hub, self).__init__(**kwargs)
365 365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366 366
367 367 # validate connection dicts:
368 368 for k,v in self.client_info.iteritems():
369 369 if k == 'task':
370 370 util.validate_url_container(v[1])
371 371 else:
372 372 util.validate_url_container(v)
373 373 # util.validate_url_container(self.client_info)
374 374 util.validate_url_container(self.engine_info)
375 375
376 376 # register our callbacks
377 377 self.query.on_recv(self.dispatch_query)
378 378 self.monitor.on_recv(self.dispatch_monitor_traffic)
379 379
380 380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382 382
383 383 self.monitor_handlers = {b'in' : self.save_queue_request,
384 384 b'out': self.save_queue_result,
385 385 b'intask': self.save_task_request,
386 386 b'outtask': self.save_task_result,
387 387 b'tracktask': self.save_task_destination,
388 388 b'incontrol': _passer,
389 389 b'outcontrol': _passer,
390 390 b'iopub': self.save_iopub_message,
391 391 }
392 392
393 393 self.query_handlers = {'queue_request': self.queue_status,
394 394 'result_request': self.get_results,
395 395 'history_request': self.get_history,
396 396 'db_request': self.db_query,
397 397 'purge_request': self.purge_results,
398 398 'load_request': self.check_load,
399 399 'resubmit_request': self.resubmit_task,
400 400 'shutdown_request': self.shutdown_request,
401 401 'registration_request' : self.register_engine,
402 402 'unregistration_request' : self.unregister_engine,
403 403 'connection_request': self.connection_request,
404 404 }
405 405
406 406 # ignore resubmit replies
407 407 self.resubmit.on_recv(lambda msg: None, copy=False)
408 408
409 409 self.log.info("hub::created hub")
410 410
411 411 @property
412 412 def _next_id(self):
413 413 """gemerate a new ID.
414 414
415 415 No longer reuse old ids, just count from 0."""
416 416 newid = self._idcounter
417 417 self._idcounter += 1
418 418 return newid
419 419 # newid = 0
420 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 421 # # print newid, self.ids, self.incoming_registrations
422 422 # while newid in self.ids or newid in incoming:
423 423 # newid += 1
424 424 # return newid
425 425
426 426 #-----------------------------------------------------------------------------
427 427 # message validation
428 428 #-----------------------------------------------------------------------------
429 429
430 430 def _validate_targets(self, targets):
431 431 """turn any valid targets argument into a list of integer ids"""
432 432 if targets is None:
433 433 # default to all
434 434 targets = self.ids
435 435
436 436 if isinstance(targets, (int,str,unicode)):
437 437 # only one target specified
438 438 targets = [targets]
439 439 _targets = []
440 440 for t in targets:
441 441 # map raw identities to ids
442 442 if isinstance(t, (str,unicode)):
443 443 t = self.by_ident.get(t, t)
444 444 _targets.append(t)
445 445 targets = _targets
446 446 bad_targets = [ t for t in targets if t not in self.ids ]
447 447 if bad_targets:
448 448 raise IndexError("No Such Engine: %r"%bad_targets)
449 449 if not targets:
450 450 raise IndexError("No Engines Registered")
451 451 return targets
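# Illustrative behavior (editorial), assuming engines 0 and 1 are registered
# and b'<queue-uuid>' is engine 0's identity:
# _validate_targets(None) -> all registered ids
# _validate_targets(b'<queue-uuid>') -> [0] (raw identity mapped to id)
# _validate_targets([5]) -> raises IndexError("No Such Engine: [5]")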
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # dispatch methods (1 per stream)
455 455 #-----------------------------------------------------------------------------
456 456
457 457
458 458 def dispatch_monitor_traffic(self, msg):
459 459 """all ME and Task queue messages come through here, as well as
460 460 IOPub traffic."""
461 461 self.log.debug("monitor traffic: %r"%msg[:2])
462 462 switch = msg[0]
463 463 try:
464 464 idents, msg = self.session.feed_identities(msg[1:])
465 465 except ValueError:
466 466 idents=[]
467 467 if not idents:
468 468 self.log.error("Bad Monitor Message: %r"%msg)
469 469 return
470 470 handler = self.monitor_handlers.get(switch, None)
471 471 if handler is not None:
472 472 handler(idents, msg)
473 473 else:
474 474 self.log.error("Invalid monitor topic: %r"%switch)
475 475
476 476
477 477 def dispatch_query(self, msg):
478 478 """Route registration requests and queries from clients."""
479 479 try:
480 480 idents, msg = self.session.feed_identities(msg)
481 481 except ValueError:
482 482 idents = []
483 483 if not idents:
484 484 self.log.error("Bad Query Message: %r"%msg)
485 485 return
486 486 client_id = idents[0]
487 487 try:
488 msg = self.session.unpack_message(msg, content=True)
488 msg = self.session.unserialize(msg, content=True)
489 489 except Exception:
490 490 content = error.wrap_exception()
491 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 492 self.session.send(self.query, "hub_error", ident=client_id,
493 493 content=content)
494 494 return
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 msg_type = msg['msg_type']
497 msg_type = msg['header']['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
501 501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 505 self.session.send(self.query, "hub_error", ident=client_id,
506 506 content=content)
507 507 return
508 508
509 509 else:
510 510 handler(idents, msg)
511 511
512 512 def dispatch_db(self, msg):
513 513 """"""
514 514 raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538 eid = self.hearts.get(heart, None)
539 539 if eid is None:
540 540 self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 541 else:
542 542 queue = self.engines[eid].queue
543 543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
547 547 def save_queue_request(self, idents, msg):
548 548 if len(idents) < 2:
549 549 self.log.error("invalid identity prefix: %r"%idents)
550 550 return
551 551 queue_id, client_id = idents[:2]
552 552 try:
553 msg = self.session.unpack_message(msg)
553 msg = self.session.unserialize(msg)
554 554 except Exception:
555 555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 556 return
557 557
558 558 eid = self.by_ident.get(queue_id, None)
559 559 if eid is None:
560 560 self.log.error("queue::target %r not registered"%queue_id)
561 561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 562 return
563 563 record = init_record(msg)
564 564 msg_id = record['msg_id']
565 565 # Unicode in records
566 566 record['engine_uuid'] = queue_id.decode('ascii')
567 567 record['client_uuid'] = client_id.decode('ascii')
568 568 record['queue'] = 'mux'
569 569
570 570 try:
571 571 # it's possible iopub arrived first:
572 572 existing = self.db.get_record(msg_id)
573 573 for key,evalue in existing.iteritems():
574 574 rvalue = record.get(key, None)
575 575 if evalue and rvalue and evalue != rvalue:
576 576 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
577 577 elif evalue and not rvalue:
578 578 record[key] = evalue
579 579 try:
580 580 self.db.update_record(msg_id, record)
581 581 except Exception:
582 582 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
583 583 except KeyError:
584 584 try:
585 585 self.db.add_record(msg_id, record)
586 586 except Exception:
587 587 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
588 588
589 589
590 590 self.pending.add(msg_id)
591 591 self.queues[eid].append(msg_id)
592 592
593 593 def save_queue_result(self, idents, msg):
594 594 if len(idents) < 2:
595 595 self.log.error("invalid identity prefix: %r"%idents)
596 596 return
597 597
598 598 client_id, queue_id = idents[:2]
599 599 try:
600 msg = self.session.unpack_message(msg)
600 msg = self.session.unserialize(msg)
601 601 except Exception:
602 602 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
603 603 queue_id,client_id, msg), exc_info=True)
604 604 return
605 605
606 606 eid = self.by_ident.get(queue_id, None)
607 607 if eid is None:
608 608 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
609 609 return
610 610
611 611 parent = msg['parent_header']
612 612 if not parent:
613 613 return
614 614 msg_id = parent['msg_id']
615 615 if msg_id in self.pending:
616 616 self.pending.remove(msg_id)
617 617 self.all_completed.add(msg_id)
618 618 self.queues[eid].remove(msg_id)
619 619 self.completed[eid].append(msg_id)
620 620 elif msg_id not in self.all_completed:
621 621 # it could be a result from a dead engine that died before delivering the
622 622 # result
623 623 self.log.warn("queue:: unknown msg finished %r"%msg_id)
624 624 return
625 625 # update record anyway, because the unregistration could have been premature
626 626 rheader = msg['header']
627 627 completed = rheader['date']
628 628 started = rheader.get('started', None)
629 629 result = {
630 630 'result_header' : rheader,
631 631 'result_content': msg['content'],
632 632 'started' : started,
633 633 'completed' : completed
634 634 }
635 635
636 636 result['result_buffers'] = msg['buffers']
637 637 try:
638 638 self.db.update_record(msg_id, result)
639 639 except Exception:
640 640 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
641 641
642 642
643 643 #--------------------- Task Queue Traffic ------------------------------
644 644
645 645 def save_task_request(self, idents, msg):
646 646 """Save the submission of a task."""
647 647 client_id = idents[0]
648 648
649 649 try:
650 msg = self.session.unpack_message(msg)
650 msg = self.session.unserialize(msg)
651 651 except Exception:
652 652 self.log.error("task::client %r sent invalid task message: %r"%(
653 653 client_id, msg), exc_info=True)
654 654 return
655 655 record = init_record(msg)
656 656
657 657 record['client_uuid'] = client_id
658 658 record['queue'] = 'task'
659 659 header = msg['header']
660 660 msg_id = header['msg_id']
661 661 self.pending.add(msg_id)
662 662 self.unassigned.add(msg_id)
663 663 try:
664 664 # it's possible iopub arrived first:
665 665 existing = self.db.get_record(msg_id)
666 666 if existing['resubmitted']:
667 667 for key in ('submitted', 'client_uuid', 'buffers'):
668 668 # don't clobber these keys on resubmit
669 669 # submitted and client_uuid should be different
670 670 # and buffers might be big, and shouldn't have changed
671 671 record.pop(key)
672 672 # still check content,header which should not change
673 673 # but are not as expensive to compare as buffers
674 674
675 675 for key,evalue in existing.iteritems():
676 676 if key.endswith('buffers'):
677 677 # don't compare buffers
678 678 continue
679 679 rvalue = record.get(key, None)
680 680 if evalue and rvalue and evalue != rvalue:
681 681 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
682 682 elif evalue and not rvalue:
683 683 record[key] = evalue
684 684 try:
685 685 self.db.update_record(msg_id, record)
686 686 except Exception:
687 687 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
688 688 except KeyError:
689 689 try:
690 690 self.db.add_record(msg_id, record)
691 691 except Exception:
692 692 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
693 693 except Exception:
694 694 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
695 695
696 696 def save_task_result(self, idents, msg):
697 697 """save the result of a completed task."""
698 698 client_id = idents[0]
699 699 try:
700 msg = self.session.unpack_message(msg)
700 msg = self.session.unserialize(msg)
701 701 except Exception:
702 702 self.log.error("task::invalid task result message send to %r: %r"%(
703 703 client_id, msg), exc_info=True)
704 704 return
705 705
706 706 parent = msg['parent_header']
707 707 if not parent:
708 708 # print msg
709 709 self.log.warn("Task %r had no parent!"%msg)
710 710 return
711 711 msg_id = parent['msg_id']
712 712 if msg_id in self.unassigned:
713 713 self.unassigned.remove(msg_id)
714 714
715 715 header = msg['header']
716 716 engine_uuid = header.get('engine', None)
717 717 eid = self.by_ident.get(engine_uuid, None)
718 718
719 719 if msg_id in self.pending:
720 720 self.pending.remove(msg_id)
721 721 self.all_completed.add(msg_id)
722 722 if eid is not None:
723 723 self.completed[eid].append(msg_id)
724 724 if msg_id in self.tasks[eid]:
725 725 self.tasks[eid].remove(msg_id)
726 726 completed = header['date']
727 727 started = header.get('started', None)
728 728 result = {
729 729 'result_header' : header,
730 730 'result_content': msg['content'],
731 731 'started' : started,
732 732 'completed' : completed,
733 733 'engine_uuid': engine_uuid
734 734 }
735 735
736 736 result['result_buffers'] = msg['buffers']
737 737 try:
738 738 self.db.update_record(msg_id, result)
739 739 except Exception:
740 740 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
741 741
742 742 else:
743 743 self.log.debug("task::unknown task %r finished"%msg_id)
744 744
745 745 def save_task_destination(self, idents, msg):
746 746 try:
747 msg = self.session.unpack_message(msg, content=True)
747 msg = self.session.unserialize(msg, content=True)
748 748 except Exception:
749 749 self.log.error("task::invalid task tracking message", exc_info=True)
750 750 return
751 751 content = msg['content']
752 752 # print (content)
753 753 msg_id = content['msg_id']
754 754 engine_uuid = content['engine_id']
755 755 eid = self.by_ident[util.asbytes(engine_uuid)]
756 756
757 757 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
758 758 if msg_id in self.unassigned:
759 759 self.unassigned.remove(msg_id)
760 760 # else:
761 761 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
762 762
763 763 self.tasks[eid].append(msg_id)
764 764 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
765 765 try:
766 766 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
767 767 except Exception:
768 768 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
769 769
770 770
771 771 def mia_task_request(self, idents, msg):
772 772 raise NotImplementedError
773 773 client_id = idents[0]
774 774 # content = dict(mia=self.mia,status='ok')
775 775 # self.session.send('mia_reply', content=content, idents=client_id)
776 776
777 777
778 778 #--------------------- IOPub Traffic ------------------------------
779 779
780 780 def save_iopub_message(self, topics, msg):
781 781 """save an iopub message into the db"""
782 782 # print (topics)
783 783 try:
784 msg = self.session.unpack_message(msg, content=True)
784 msg = self.session.unserialize(msg, content=True)
785 785 except Exception:
786 786 self.log.error("iopub::invalid IOPub message", exc_info=True)
787 787 return
788 788
789 789 parent = msg['parent_header']
790 790 if not parent:
791 791 self.log.error("iopub::invalid IOPub message: %r"%msg)
792 792 return
793 793 msg_id = parent['msg_id']
794 msg_type = msg['msg_type']
794 msg_type = msg['header']['msg_type']
795 795 content = msg['content']
796 796
797 797 # ensure msg_id is in db
798 798 try:
799 799 rec = self.db.get_record(msg_id)
800 800 except KeyError:
801 801 rec = empty_record()
802 802 rec['msg_id'] = msg_id
803 803 self.db.add_record(msg_id, rec)
804 804 # stream
805 805 d = {}
806 806 if msg_type == 'stream':
807 807 name = content['name']
808 808 s = rec[name] or ''
809 809 d[name] = s + content['data']
810 810
811 811 elif msg_type == 'pyerr':
812 812 d['pyerr'] = content
813 813 elif msg_type == 'pyin':
814 814 d['pyin'] = content['code']
815 815 else:
816 816 d[msg_type] = content.get('data', '')
817 817
818 818 try:
819 819 self.db.update_record(msg_id, d)
820 820 except Exception:
821 821 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
822 822
823 823
824 824
825 825 #-------------------------------------------------------------------------
826 826 # Registration requests
827 827 #-------------------------------------------------------------------------
828 828
829 829 def connection_request(self, client_id, msg):
830 830 """Reply with connection addresses for clients."""
831 831 self.log.info("client::client %r connected"%client_id)
832 832 content = dict(status='ok')
833 833 content.update(self.client_info)
834 834 jsonable = {}
835 835 for k,v in self.keytable.iteritems():
836 836 if v not in self.dead_engines:
837 837 jsonable[str(k)] = v.decode('ascii')
838 838 content['engines'] = jsonable
839 839 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
840 840
841 841 def register_engine(self, reg, msg):
842 842 """Register a new engine."""
843 843 content = msg['content']
844 844 try:
845 845 queue = util.asbytes(content['queue'])
846 846 except KeyError:
847 847 self.log.error("registration::queue not specified", exc_info=True)
848 848 return
849 849 heart = content.get('heartbeat', None)
850 850 if heart:
851 851 heart = util.asbytes(heart)
852 852 """register a new engine, and create the socket(s) necessary"""
853 853 eid = self._next_id
854 854 # print (eid, queue, reg, heart)
855 855
856 856 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
857 857
858 858 content = dict(id=eid,status='ok')
859 859 content.update(self.engine_info)
860 860 # check if requesting available IDs:
861 861 if queue in self.by_ident:
862 862 try:
863 863 raise KeyError("queue_id %r in use"%queue)
864 864 except:
865 865 content = error.wrap_exception()
866 866 self.log.error("queue_id %r in use"%queue, exc_info=True)
867 867 elif heart in self.hearts: # need to check unique hearts?
868 868 try:
869 869 raise KeyError("heart_id %r in use"%heart)
870 870 except:
871 871 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 872 content = error.wrap_exception()
873 873 else:
874 874 for h, pack in self.incoming_registrations.iteritems():
875 875 if heart == h:
876 876 try:
877 877 raise KeyError("heart_id %r in use"%heart)
878 878 except:
879 879 self.log.error("heart_id %r in use"%heart, exc_info=True)
880 880 content = error.wrap_exception()
881 881 break
882 882 elif queue == pack[1]:
883 883 try:
884 884 raise KeyError("queue_id %r in use"%queue)
885 885 except:
886 886 self.log.error("queue_id %r in use"%queue, exc_info=True)
887 887 content = error.wrap_exception()
888 888 break
889 889
890 890 msg = self.session.send(self.query, "registration_reply",
891 891 content=content,
892 892 ident=reg)
893 893
894 894 if content['status'] == 'ok':
895 895 if heart in self.heartmonitor.hearts:
896 896 # already beating
897 897 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
898 898 self.finish_registration(heart)
899 899 else:
900 900 purge = lambda : self._purge_stalled_registration(heart)
901 901 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
902 902 dc.start()
903 903 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
904 904 else:
905 905 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
906 906 return eid
907 907
908 908 def unregister_engine(self, ident, msg):
909 909 """Unregister an engine that explicitly requested to leave."""
910 910 try:
911 911 eid = msg['content']['id']
912 912 except:
913 913 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
914 914 return
915 915 self.log.info("registration::unregister_engine(%r)"%eid)
916 916 # print (eid)
917 917 uuid = self.keytable[eid]
918 918 content=dict(id=eid, queue=uuid.decode('ascii'))
919 919 self.dead_engines.add(uuid)
920 920 # self.ids.remove(eid)
921 921 # uuid = self.keytable.pop(eid)
922 922 #
923 923 # ec = self.engines.pop(eid)
924 924 # self.hearts.pop(ec.heartbeat)
925 925 # self.by_ident.pop(ec.queue)
926 926 # self.completed.pop(eid)
927 927 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
928 928 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
929 929 dc.start()
930 930 ############## TODO: HANDLE IT ################
931 931
932 932 if self.notifier:
933 933 self.session.send(self.notifier, "unregistration_notification", content=content)
934 934
935 935 def _handle_stranded_msgs(self, eid, uuid):
936 936 """Handle messages known to be on an engine when the engine unregisters.
937 937
938 938 It is possible that this will fire prematurely - that is, an engine will
939 939 go down after completing a result, and the client will be notified
940 940 that the result failed and later receive the actual result.
941 941 """
942 942
943 943 outstanding = self.queues[eid]
944 944
945 945 for msg_id in outstanding:
946 946 self.pending.remove(msg_id)
947 947 self.all_completed.add(msg_id)
948 948 try:
949 949 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
950 950 except:
951 951 content = error.wrap_exception()
952 952 # build a fake header:
953 953 header = {}
954 954 header['engine'] = uuid
955 955 header['date'] = datetime.now()
956 956 rec = dict(result_content=content, result_header=header, result_buffers=[])
957 957 rec['completed'] = header['date']
958 958 rec['engine_uuid'] = uuid
959 959 try:
960 960 self.db.update_record(msg_id, rec)
961 961 except Exception:
962 962 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
963 963
964 964
965 965 def finish_registration(self, heart):
966 966 """Second half of engine registration, called after our HeartMonitor
967 967 has received a beat from the Engine's Heart."""
968 968 try:
969 969 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
970 970 except KeyError:
971 971 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
972 972 return
973 973 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
974 974 if purge is not None:
975 975 purge.stop()
976 976 control = queue
977 977 self.ids.add(eid)
978 978 self.keytable[eid] = queue
979 979 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
980 980 control=control, heartbeat=heart)
981 981 self.by_ident[queue] = eid
982 982 self.queues[eid] = list()
983 983 self.tasks[eid] = list()
984 984 self.completed[eid] = list()
985 985 self.hearts[heart] = eid
986 986 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
987 987 if self.notifier:
988 988 self.session.send(self.notifier, "registration_notification", content=content)
989 989 self.log.info("engine::Engine Connected: %i"%eid)
990 990
991 991 def _purge_stalled_registration(self, heart):
992 992 if heart in self.incoming_registrations:
993 993 eid = self.incoming_registrations.pop(heart)[0]
994 994 self.log.info("registration::purging stalled registration: %i"%eid)
997 997
998 998 #-------------------------------------------------------------------------
999 999 # Client Requests
1000 1000 #-------------------------------------------------------------------------
1001 1001
1002 1002 def shutdown_request(self, client_id, msg):
1003 1003 """handle shutdown request."""
1004 1004 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1005 1005 # also notify other clients of shutdown
1006 1006 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1007 1007 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1008 1008 dc.start()
1009 1009
1010 1010 def _shutdown(self):
1011 1011 self.log.info("hub::hub shutting down.")
1012 1012 time.sleep(0.1)
1013 1013 sys.exit(0)
1014 1014
1015 1015
1016 1016 def check_load(self, client_id, msg):
1017 1017 content = msg['content']
1018 1018 try:
1019 1019 targets = content['targets']
1020 1020 targets = self._validate_targets(targets)
1021 1021 except:
1022 1022 content = error.wrap_exception()
1023 1023 self.session.send(self.query, "hub_error",
1024 1024 content=content, ident=client_id)
1025 1025 return
1026 1026
1027 1027 content = dict(status='ok')
1028 1028 # loads = {}
1029 1029 for t in targets:
1030 1030 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1031 1031 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1032 1032
1033 1033
1034 1034 def queue_status(self, client_id, msg):
1035 1035 """Return the Queue status of one or more targets.
1036 1036 if verbose: return the msg_ids
1037 1037 else: return len of each type.
1038 1038 keys: queue (pending MUX jobs)
1039 1039 tasks (pending Task jobs)
1040 1040 completed (finished jobs from both queues)"""
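# An illustrative (non-verbatim) reply for two engines with verbose=False:
# {'status': 'ok',
# '0': {'queue': 2, 'completed': 10, 'tasks': 1},
# '1': {'queue': 0, 'completed': 12, 'tasks': 3},
# 'unassigned': 0}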
1041 1041 content = msg['content']
1042 1042 targets = content['targets']
1043 1043 try:
1044 1044 targets = self._validate_targets(targets)
1045 1045 except:
1046 1046 content = error.wrap_exception()
1047 1047 self.session.send(self.query, "hub_error",
1048 1048 content=content, ident=client_id)
1049 1049 return
1050 1050 verbose = content.get('verbose', False)
1051 1051 content = dict(status='ok')
1052 1052 for t in targets:
1053 1053 queue = self.queues[t]
1054 1054 completed = self.completed[t]
1055 1055 tasks = self.tasks[t]
1056 1056 if not verbose:
1057 1057 queue = len(queue)
1058 1058 completed = len(completed)
1059 1059 tasks = len(tasks)
1060 1060 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1061 1061 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1062 1062 # print (content)
1063 1063 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1064 1064
1065 1065 def purge_results(self, client_id, msg):
1066 1066 """Purge results from memory. This method is more valuable before we move
1067 1067 to a DB based message storage mechanism."""
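# The request content this handler accepts looks like (illustrative):
# {'msg_ids': ['<msg-id>', ...]} or {'msg_ids': 'all'}, optionally with
# {'engine_ids': [0, 1]} to also drop each engine's completed records.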
1068 1068 content = msg['content']
1069 1069 self.log.info("Dropping records with %s", content)
1070 1070 msg_ids = content.get('msg_ids', [])
1071 1071 reply = dict(status='ok')
1072 1072 if msg_ids == 'all':
1073 1073 try:
1074 1074 self.db.drop_matching_records(dict(completed={'$ne':None}))
1075 1075 except Exception:
1076 1076 reply = error.wrap_exception()
1077 1077 else:
1078 1078 pending = filter(lambda m: m in self.pending, msg_ids)
1079 1079 if pending:
1080 1080 try:
1081 1081 raise IndexError("msg pending: %r"%pending[0])
1082 1082 except:
1083 1083 reply = error.wrap_exception()
1084 1084 else:
1085 1085 try:
1086 1086 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1087 1087 except Exception:
1088 1088 reply = error.wrap_exception()
1089 1089
1090 1090 if reply['status'] == 'ok':
1091 1091 eids = content.get('engine_ids', [])
1092 1092 for eid in eids:
1093 1093 if eid not in self.engines:
1094 1094 try:
1095 1095 raise IndexError("No such engine: %i"%eid)
1096 1096 except:
1097 1097 reply = error.wrap_exception()
1098 1098 break
1099 1099 uid = self.engines[eid].queue
1100 1100 try:
1101 1101 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1102 1102 except Exception:
1103 1103 reply = error.wrap_exception()
1104 1104 break
1105 1105
1106 1106 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1107 1107
1108 1108 def resubmit_task(self, client_id, msg):
1109 1109 """Resubmit one or more tasks."""
1110 1110 def finish(reply):
1111 1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112 1112
1113 1113 content = msg['content']
1114 1114 msg_ids = content['msg_ids']
1115 1115 reply = dict(status='ok')
1116 1116 try:
1117 1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 1118 'header', 'content', 'buffers'])
1119 1119 except Exception:
1120 1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 1121 return finish(error.wrap_exception())
1122 1122
1123 1123 # validate msg_ids
1124 1124 found_ids = [ rec['msg_id'] for rec in records ]
1125 1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 1126 if len(records) > len(msg_ids):
1127 1127 try:
1128 1128 raise RuntimeError("DB appears to be in an inconsistent state."
1129 1129 "More matching records were found than should exist")
1130 1130 except Exception:
1131 1131 return finish(error.wrap_exception())
1132 1132 elif len(records) < len(msg_ids):
1133 1133 missing = [ m for m in msg_ids if m not in found_ids ]
1134 1134 try:
1135 1135 raise KeyError("No such msg(s): %r"%missing)
1136 1136 except KeyError:
1137 1137 return finish(error.wrap_exception())
1138 1138 elif invalid_ids:
1139 1139 msg_id = invalid_ids[0]
1140 1140 try:
1141 1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 1142 except Exception:
1143 1143 return finish(error.wrap_exception())
1144 1144
1145 1145 # clear the existing records
1146 1146 now = datetime.now()
1147 1147 rec = empty_record()
1148 1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 1149 rec['resubmitted'] = now
1150 1150 rec['queue'] = 'task'
1151 1151 rec['client_uuid'] = client_id[0]
1152 1152 try:
1153 1153 for msg_id in msg_ids:
1154 1154 self.all_completed.discard(msg_id)
1155 1155 self.db.update_record(msg_id, rec)
1156 1156 except Exception:
1157 1157 self.log.error('db::db error updating record', exc_info=True)
1158 1158 reply = error.wrap_exception()
1159 1159 else:
1160 1160 # send the messages
1161 1161 for rec in records:
1162 1162 header = rec['header']
1163 1163 # include resubmitted in header to prevent digest collision
1164 1164 header['resubmitted'] = now
1165 1165 msg = self.session.msg(header['msg_type'])
1166 1166 msg['content'] = rec['content']
1167 1167 msg['header'] = header
1168 msg['msg_id'] = rec['msg_id']
1168 msg['header']['msg_id'] = rec['msg_id']
1169 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170 1170
1171 1171 finish(dict(status='ok'))
1172 1172
1173 1173
1174 1174 def _extract_record(self, rec):
1175 1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 1176 io_dict = {}
1177 1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 1178 io_dict[key] = rec[key]
1179 1179 content = { 'result_content': rec['result_content'],
1180 1180 'header': rec['header'],
1181 1181 'result_header' : rec['result_header'],
1182 1182 'io' : io_dict,
1183 1183 }
1184 1184 if rec['result_buffers']:
1185 1185 buffers = map(bytes, rec['result_buffers'])
1186 1186 else:
1187 1187 buffers = []
1188 1188
1189 1189 return content, buffers
1190 1190
1191 1191 def get_results(self, client_id, msg):
1192 1192 """Get the result of 1 or more messages."""
1193 1193 content = msg['content']
1194 1194 msg_ids = sorted(set(content['msg_ids']))
1195 1195 statusonly = content.get('status_only', False)
1196 1196 pending = []
1197 1197 completed = []
1198 1198 content = dict(status='ok')
1199 1199 content['pending'] = pending
1200 1200 content['completed'] = completed
1201 1201 buffers = []
1202 1202 if not statusonly:
1203 1203 try:
1204 1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 1205 # turn match list into dict, for faster lookup
1206 1206 records = {}
1207 1207 for rec in matches:
1208 1208 records[rec['msg_id']] = rec
1209 1209 except Exception:
1210 1210 content = error.wrap_exception()
1211 1211 self.session.send(self.query, "result_reply", content=content,
1212 1212 parent=msg, ident=client_id)
1213 1213 return
1214 1214 else:
1215 1215 records = {}
1216 1216 for msg_id in msg_ids:
1217 1217 if msg_id in self.pending:
1218 1218 pending.append(msg_id)
1219 1219 elif msg_id in self.all_completed:
1220 1220 completed.append(msg_id)
1221 1221 if not statusonly:
1222 1222 c,bufs = self._extract_record(records[msg_id])
1223 1223 content[msg_id] = c
1224 1224 buffers.extend(bufs)
1225 1225 elif msg_id in records:
1226 1226 if records[msg_id]['completed']:
1227 1227 completed.append(msg_id)
1228 1228 c,bufs = self._extract_record(records[msg_id])
1229 1229 content[msg_id] = c
1230 1230 buffers.extend(bufs)
1231 1231 else:
1232 1232 pending.append(msg_id)
1233 1233 else:
1234 1234 try:
1235 1235 raise KeyError('No such message: '+msg_id)
1236 1236 except:
1237 1237 content = error.wrap_exception()
1238 1238 break
1239 1239 self.session.send(self.query, "result_reply", content=content,
1240 1240 parent=msg, ident=client_id,
1241 1241 buffers=buffers)
1242 1242
1243 1243 def get_history(self, client_id, msg):
1244 1244 """Get a list of all msg_ids in our DB records"""
1245 1245 try:
1246 1246 msg_ids = self.db.get_history()
1247 1247 except Exception as e:
1248 1248 content = error.wrap_exception()
1249 1249 else:
1250 1250 content = dict(status='ok', history=msg_ids)
1251 1251
1252 1252 self.session.send(self.query, "history_reply", content=content,
1253 1253 parent=msg, ident=client_id)
1254 1254
1255 1255 def db_query(self, client_id, msg):
1256 1256 """Perform a raw query on the task record database."""
1257 1257 content = msg['content']
1258 1258 query = content.get('query', {})
1259 1259 keys = content.get('keys', None)
1260 1260 buffers = []
1261 1261 empty = list()
1262 1262 try:
1263 1263 records = self.db.find_records(query, keys)
1264 1264 except Exception as e:
1265 1265 content = error.wrap_exception()
1266 1266 else:
1267 1267 # extract buffers from reply content:
1268 1268 if keys is not None:
1269 1269 buffer_lens = [] if 'buffers' in keys else None
1270 1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1271 1271 else:
1272 1272 buffer_lens = []
1273 1273 result_buffer_lens = []
1274 1274
1275 1275 for rec in records:
1276 1276 # buffers may be None, so double check
1277 1277 if buffer_lens is not None:
1278 1278 b = rec.pop('buffers', empty) or empty
1279 1279 buffer_lens.append(len(b))
1280 1280 buffers.extend(b)
1281 1281 if result_buffer_lens is not None:
1282 1282 rb = rec.pop('result_buffers', empty) or empty
1283 1283 result_buffer_lens.append(len(rb))
1284 1284 buffers.extend(rb)
1285 1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1286 1286 result_buffer_lens=result_buffer_lens)
1287 1287 # self.log.debug (content)
1288 1288 self.session.send(self.query, "db_reply", content=content,
1289 1289 parent=msg, ident=client_id,
1290 1290 buffers=buffers)
1291 1291
@@ -1,714 +1,714 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
58 58 #----------------------------------------------------------------------
59 59 # Chooser functions
60 60 #----------------------------------------------------------------------
61 61
62 62 def plainrandom(loads):
63 63 """Plain random pick."""
64 64 n = len(loads)
65 65 return randint(0,n-1)
66 66
67 67 def lru(loads):
68 68 """Always pick the front of the line.
69 69
70 70 The content of `loads` is ignored.
71 71
72 72 Assumes LRU ordering of loads, with oldest first.
73 73 """
74 74 return 0
75 75
76 76 def twobin(loads):
77 77 """Pick two at random, use the LRU of the two.
78 78
79 79 The content of loads is ignored.
80 80
81 81 Assumes LRU ordering of loads, with oldest first.
82 82 """
83 83 n = len(loads)
84 84 a = randint(0,n-1)
85 85 b = randint(0,n-1)
86 86 return min(a,b)
87 87
88 88 def weighted(loads):
89 89 """Pick two at random using inverse load as weight.
90 90
91 91 Return the less loaded of the two.
92 92 """
93 93 # weight 0 a million times more than 1:
94 94 weights = 1./(1e-6+numpy.array(loads))
95 95 sums = weights.cumsum()
96 96 t = sums[-1]
97 97 x = random()*t
98 98 y = random()*t
99 99 idx = 0
100 100 idy = 0
101 101 while sums[idx] < x:
102 102 idx += 1
103 103 while sums[idy] < y:
104 104 idy += 1
105 105 if weights[idy] > weights[idx]:
106 106 return idy
107 107 else:
108 108 return idx
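# Worked example (editorial): for loads [0, 1] the weights are 1/1e-6 = 1e6
# and 1/(1 + 1e-6) ~ 1, so two indices are drawn with probability roughly
# proportional to their weights and the less loaded of the two is returned;
# the idle engine at index 0 wins almost every time.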
109 109
110 110 def leastload(loads):
111 111 """Always choose the lowest load.
112 112
113 113 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
115 115 the LRU of those with the lowest load is chosen.
116 116 """
117 117 return loads.index(min(loads))
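# e.g. leastload([3, 1, 1]) -> 1: the first index holding the minimum load,
# which under LRU ordering is the least recently used of the tied engines.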
118 118
119 119 #---------------------------------------------------------------------
120 120 # Classes
121 121 #---------------------------------------------------------------------
122 122 # store empty default dependency:
123 123 MET = Dependency([])
124 124
125 125 class TaskScheduler(SessionFactory):
126 126 """Python TaskScheduler object.
127 127
128 128 This is the simplest object that supports msg_id based
129 129 DAG dependencies. *Only* task msg_ids are checked, not
130 130 msg_ids of jobs submitted via the MUX queue.
131 131
132 132 """
133 133
134 134 hwm = Int(0, config=True, shortname='hwm',
135 135 help="""specify the High Water Mark (HWM) for the downstream
136 136 socket in the Task scheduler. This is the maximum number
137 137 of allowed outstanding tasks on each engine."""
138 138 )
139 139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 140 'leastload', config=True, shortname='scheme', allow_none=False,
141 141 help="""select the task scheduler scheme [default: Python LRU]
142 142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
143 143 )
144 144 def _scheme_name_changed(self, old, new):
145 145 self.log.debug("Using scheme %r"%new)
146 146 self.scheme = globals()[new]
147 147
148 148 # input arguments:
149 149 scheme = Instance(FunctionType) # function for determining the destination
150 150 def _scheme_default(self):
151 151 return leastload
152 152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
153 153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
154 154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
155 155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
156 156
157 157 # internals:
158 158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
159 159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
160 160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
161 161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
162 162 pending = Dict() # dict by engine_uuid of submitted tasks
163 163 completed = Dict() # dict by engine_uuid of completed tasks
164 164 failed = Dict() # dict by engine_uuid of failed tasks
165 165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
166 166 clients = Dict() # dict by msg_id for who submitted the task
167 167 targets = List() # list of target IDENTs
168 168 loads = List() # list of engine loads
169 169 # full = Set() # set of IDENTs that have HWM outstanding tasks
170 170 all_completed = Set() # set of all completed tasks
171 171 all_failed = Set() # set of all failed tasks
172 172 all_done = Set() # set of all finished tasks=union(completed,failed)
173 173 all_ids = Set() # set of all submitted task IDs
174 174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
175 175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176 176
177 177 ident = CBytes() # ZMQ identity. This should just be self.session.session
178 178 # but ensure Bytes
179 179 def _ident_default(self):
180 180 return asbytes(self.session.session)
181 181
182 182 def start(self):
183 183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
184 184 self._notification_handlers = dict(
185 185 registration_notification = self._register_engine,
186 186 unregistration_notification = self._unregister_engine
187 187 )
188 188 self.notifier_stream.on_recv(self.dispatch_notification)
189 189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
190 190 self.auditor.start()
191 191 self.log.info("Scheduler started [%s]"%self.scheme_name)
192 192
193 193 def resume_receiving(self):
194 194 """Resume accepting jobs."""
195 195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
196 196
197 197 def stop_receiving(self):
198 198 """Stop accepting jobs while there are no engines.
199 199 Leave them in the ZMQ queue."""
200 200 self.client_stream.on_recv(None)
201 201
202 202 #-----------------------------------------------------------------------
203 203 # [Un]Registration Handling
204 204 #-----------------------------------------------------------------------
205 205
206 206 def dispatch_notification(self, msg):
207 207 """dispatch register/unregister events."""
208 208 try:
209 209 idents,msg = self.session.feed_identities(msg)
210 210 except ValueError:
211 211 self.log.warn("task::Invalid Message: %r",msg)
212 212 return
213 213 try:
214 msg = self.session.unpack_message(msg)
214 msg = self.session.unserialize(msg)
215 215 except ValueError:
216 216 self.log.warn("task::Unauthorized message from: %r"%idents)
217 217 return
218 218
219 msg_type = msg['msg_type']
219 msg_type = msg['header']['msg_type']
220 220
221 221 handler = self._notification_handlers.get(msg_type, None)
222 222 if handler is None:
223 223 self.log.error("Unhandled message type: %r"%msg_type)
224 224 else:
225 225 try:
226 226 handler(asbytes(msg['content']['queue']))
227 227 except Exception:
228 228 self.log.error("task::Invalid notification msg: %r",msg)
229 229
230 230 def _register_engine(self, uid):
231 231 """New engine with ident `uid` became available."""
232 232 # head of the line:
233 233 self.targets.insert(0,uid)
234 234 self.loads.insert(0,0)
235 235
236 236 # initialize sets
237 237 self.completed[uid] = set()
238 238 self.failed[uid] = set()
239 239 self.pending[uid] = {}
240 240 if len(self.targets) == 1:
241 241 self.resume_receiving()
242 242 # rescan the graph:
243 243 self.update_graph(None)
244 244
245 245 def _unregister_engine(self, uid):
246 246 """Existing engine with ident `uid` became unavailable."""
247 247 if len(self.targets) == 1:
248 248 # this was our only engine
249 249 self.stop_receiving()
250 250
251 251 # handle any potentially finished tasks:
252 252 self.engine_stream.flush()
253 253
254 254 # don't pop destinations, because they might be used later
255 255 # map(self.destinations.pop, self.completed.pop(uid))
256 256 # map(self.destinations.pop, self.failed.pop(uid))
257 257
258 258 # prevent this engine from receiving work
259 259 idx = self.targets.index(uid)
260 260 self.targets.pop(idx)
261 261 self.loads.pop(idx)
262 262
263 263 # wait 5 seconds before cleaning up pending jobs, since the results might
264 264 # still be incoming
265 265 if self.pending[uid]:
266 266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
267 267 dc.start()
268 268 else:
269 269 self.completed.pop(uid)
270 270 self.failed.pop(uid)
271 271
272 272
273 273 def handle_stranded_tasks(self, engine):
274 274 """Deal with jobs resident in an engine that died."""
275 275 lost = self.pending[engine]
276 276 for msg_id in lost.keys():
277 277 if msg_id not in self.pending[engine]:
278 278 # prevent double-handling of messages
279 279 continue
280 280
281 281 raw_msg = lost[msg_id][0]
282 282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
283 283 parent = self.session.unpack(msg[1].bytes)
284 284 idents = [engine, idents[0]]
285 285
286 286 # build fake error reply
287 287 try:
288 288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
289 289 except:
290 290 content = error.wrap_exception()
291 291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
292 292 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
293 293 # and dispatch it
294 294 self.dispatch_result(raw_reply)
295 295
296 296 # finally scrub completed/failed lists
297 297 self.completed.pop(engine)
298 298 self.failed.pop(engine)
299 299
300 300
301 301 #-----------------------------------------------------------------------
302 302 # Job Submission
303 303 #-----------------------------------------------------------------------
304 304 def dispatch_submission(self, raw_msg):
305 305 """Dispatch job submission to appropriate handlers."""
306 306 # ensure targets up to date:
307 307 self.notifier_stream.flush()
308 308 try:
309 309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
310 msg = self.session.unpack_message(msg, content=False, copy=False)
310 msg = self.session.unserialize(msg, content=False, copy=False)
311 311 except Exception:
312 312 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
313 313 return
314 314
315 315
316 316 # send to monitor
317 317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
318 318
319 319 header = msg['header']
320 320 msg_id = header['msg_id']
321 321 self.all_ids.add(msg_id)
322 322
323 323 # get targets as a set of bytes objects
324 324 # from a list of unicode objects
325 325 targets = header.get('targets', [])
326 326 targets = map(asbytes, targets)
327 327 targets = set(targets)
328 328
329 329 retries = header.get('retries', 0)
330 330 self.retries[msg_id] = retries
331 331
332 332 # time dependencies
333 333 after = header.get('after', None)
334 334 if after:
335 335 after = Dependency(after)
336 336 if after.all:
337 337 if after.success:
338 338 after = Dependency(after.difference(self.all_completed),
339 339 success=after.success,
340 340 failure=after.failure,
341 341 all=after.all,
342 342 )
343 343 if after.failure:
344 344 after = Dependency(after.difference(self.all_failed),
345 345 success=after.success,
346 346 failure=after.failure,
347 347 all=after.all,
348 348 )
349 349 if after.check(self.all_completed, self.all_failed):
350 350 # recast as empty set, if `after` already met,
351 351 # to prevent unnecessary set comparisons
352 352 after = MET
353 353 else:
354 354 after = MET
355 355
356 356 # location dependencies
357 357 follow = Dependency(header.get('follow', []))
358 358
359 359 # turn timeouts into datetime objects:
360 360 timeout = header.get('timeout', None)
361 361 if timeout:
362 362 timeout = datetime.now() + timedelta(0,timeout,0)
363 363
364 364 args = [raw_msg, targets, after, follow, timeout]
365 365
366 366 # validate and reduce dependencies:
367 367 for dep in after,follow:
368 368 if not dep: # empty dependency
369 369 continue
370 370 # check valid:
371 371 if msg_id in dep or dep.difference(self.all_ids):
372 372 self.depending[msg_id] = args
373 373 return self.fail_unreachable(msg_id, error.InvalidDependency)
374 374 # check if unreachable:
375 375 if dep.unreachable(self.all_completed, self.all_failed):
376 376 self.depending[msg_id] = args
377 377 return self.fail_unreachable(msg_id)
378 378
379 379 if after.check(self.all_completed, self.all_failed):
380 380 # time deps already met, try to run
381 381 if not self.maybe_run(msg_id, *args):
382 382 # can't run yet
383 383 if msg_id not in self.all_failed:
384 384 # could have failed as unreachable
385 385 self.save_unmet(msg_id, *args)
386 386 else:
387 387 self.save_unmet(msg_id, *args)
388 388
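# Editorial sketch of the Dependency checks used above, with hypothetical
# msg_ids and the default all/success semantics (only methods exercised in
# this file are shown):
# >>> dep = Dependency(['a', 'b']) # met once both 'a' and 'b' complete
# >>> dep.check({'a'}, set()) # args are (all_completed, all_failed)
# False
# >>> dep.check({'a', 'b'}, set())
# True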
389 389 def audit_timeouts(self):
390 390 """Audit all waiting tasks for expired timeouts."""
391 391 now = datetime.now()
392 392 for msg_id in self.depending.keys():
393 393 # must recheck, in case one failure cascaded to another:
394 394 if msg_id in self.depending:
395 395 raw,targets,after,follow,timeout = self.depending[msg_id]
396 396 if timeout and timeout < now:
397 397 self.fail_unreachable(msg_id, error.TaskTimeout)
398 398
399 399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
400 400 """a task has become unreachable, send a reply with an ImpossibleDependency
401 401 error."""
402 402 if msg_id not in self.depending:
403 403 self.log.error("msg %r already failed!", msg_id)
404 404 return
405 405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
406 406 for mid in follow.union(after):
407 407 if mid in self.graph:
408 408 self.graph[mid].remove(msg_id)
409 409
410 410 # FIXME: unpacking a message I've already unpacked, but didn't save:
411 411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
412 412 header = self.session.unpack(msg[1].bytes)
413 413
414 414 try:
415 415 raise why()
416 416 except:
417 417 content = error.wrap_exception()
418 418
419 419 self.all_done.add(msg_id)
420 420 self.all_failed.add(msg_id)
421 421
422 422 msg = self.session.send(self.client_stream, 'apply_reply', content,
423 423 parent=header, ident=idents)
424 424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
425 425
426 426 self.update_graph(msg_id, success=False)
427 427
428 428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
429 429 """check location dependencies, and run if they are met."""
430 430 blacklist = self.blacklist.setdefault(msg_id, set())
431 431 if follow or targets or blacklist or self.hwm:
432 432 # we need a can_run filter
433 433 def can_run(idx):
434 434 # check hwm
435 435 if self.hwm and self.loads[idx] == self.hwm:
436 436 return False
437 437 target = self.targets[idx]
438 438 # check blacklist
439 439 if target in blacklist:
440 440 return False
441 441 # check targets
442 442 if targets and target not in targets:
443 443 return False
444 444 # check follow
445 445 return follow.check(self.completed[target], self.failed[target])
446 446
447 447 indices = filter(can_run, range(len(self.targets)))
448 448
449 449 if not indices:
450 450 # couldn't run
451 451 if follow.all:
452 452 # check follow for impossibility
453 453 dests = set()
454 454 relevant = set()
455 455 if follow.success:
456 456 relevant = self.all_completed
457 457 if follow.failure:
458 458 relevant = relevant.union(self.all_failed)
459 459 for m in follow.intersection(relevant):
460 460 dests.add(self.destinations[m])
461 461 if len(dests) > 1:
462 462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
463 463 self.fail_unreachable(msg_id)
464 464 return False
465 465 if targets:
466 466 # check blacklist+targets for impossibility
467 467 targets.difference_update(blacklist)
468 468 if not targets or not targets.intersection(self.targets):
469 469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
470 470 self.fail_unreachable(msg_id)
471 471 return False
472 472 return False
473 473 else:
474 474 indices = None
475 475
476 476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
477 477 return True
478 478
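A rough standalone sketch of the can_run filter above, with hypothetical loads, engines, and blacklist (the follow-dependency check is omitted here):

    loads = [2, 0, 1]                  # outstanding jobs per engine
    hwm = 2                            # high-water mark; 0 disables the check
    engines = ['e0', 'e1', 'e2']
    blacklist = set(['e1'])            # engines that already failed this task

    def can_run(idx):
        if hwm and loads[idx] == hwm:  # engine is saturated
            return False
        return engines[idx] not in blacklist

    print(filter(can_run, range(len(engines))))   # [2]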
479 479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
480 480 """Save a message for later submission when its dependencies are met."""
481 481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
482 482 # track the ids in follow or after, but not those already finished
483 483 for dep_id in after.union(follow).difference(self.all_done):
484 484 if dep_id not in self.graph:
485 485 self.graph[dep_id] = set()
486 486 self.graph[dep_id].add(msg_id)
487 487
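The graph built here is a reverse dependency map: each unfinished dependency points at the set of msg_ids blocked on it, which update_graph later pops and rechecks. A toy illustration with invented ids:

    # graph maps dep_id -> set of msg_ids waiting on it
    graph = {}
    for dep_id in ['a', 'b']:          # task 't' depends on 'a' and 'b'
        graph.setdefault(dep_id, set()).add('t')

    # when 'a' finishes, pop its entry and recheck the waiting tasks:
    waiting = graph.pop('a', [])
    print(waiting)                      # set(['t']) -- 't' may now be runnable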
488 488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
489 489 """Submit a task to any of a subset of our targets."""
490 490 if indices:
491 491 loads = [self.loads[i] for i in indices]
492 492 else:
493 493 loads = self.loads
494 494 idx = self.scheme(loads)
495 495 if indices:
496 496 idx = indices[idx]
497 497 target = self.targets[idx]
498 498 # print (target, map(str, msg[:3]))
499 499 # send job to the engine
500 500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
501 501 self.engine_stream.send_multipart(raw_msg, copy=False)
502 502 # update load
503 503 self.add_job(idx)
504 504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
505 505 # notify Hub
506 506 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
507 507 self.session.send(self.mon_stream, 'task_destination', content=content,
508 508 ident=[b'tracktask',self.ident])
509 509
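self.scheme reduces a list of loads to an index. The stock policy in IPython.parallel is a least-load pick, which together with the LRU rotation in add_job below breaks ties toward the least recently used engine; a sketch of such a scheme:

    def leastload(loads):
        """Pick the index of the lowest load; the first occurrence wins."""
        return loads.index(min(loads))

    print(leastload([2, 0, 1]))   # 1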
510 510
511 511 #-----------------------------------------------------------------------
512 512 # Result Handling
513 513 #-----------------------------------------------------------------------
514 514 def dispatch_result(self, raw_msg):
515 515 """dispatch method for result replies"""
516 516 try:
517 517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
518 msg = self.session.unpack_message(msg, content=False, copy=False)
518 msg = self.session.unserialize(msg, content=False, copy=False)
519 519 engine = idents[0]
520 520 try:
521 521 idx = self.targets.index(engine)
522 522 except ValueError:
523 523 pass # skip load-update for dead engines
524 524 else:
525 525 self.finish_job(idx)
526 526 except Exception:
527 527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
528 528 return
529 529
530 530 header = msg['header']
531 531 parent = msg['parent_header']
532 532 if header.get('dependencies_met', True):
533 533 success = (header['status'] == 'ok')
534 534 msg_id = parent['msg_id']
535 535 retries = self.retries[msg_id]
536 536 if not success and retries > 0:
537 537 # failed
538 538 self.retries[msg_id] = retries - 1
539 539 self.handle_unmet_dependency(idents, parent)
540 540 else:
541 541 del self.retries[msg_id]
542 542 # relay to client and update graph
543 543 self.handle_result(idents, parent, raw_msg, success)
544 544 # send to Hub monitor
545 545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
546 546 else:
547 547 self.handle_unmet_dependency(idents, parent)
548 548
549 549 def handle_result(self, idents, parent, raw_msg, success=True):
550 550 """handle a real task result, either success or failure"""
551 551 # first, relay result to client
552 552 engine = idents[0]
553 553 client = idents[1]
554 554 # swap_ids for XREP-XREP mirror
555 555 raw_msg[:2] = [client,engine]
556 556 # print (map(str, raw_msg[:4]))
557 557 self.client_stream.send_multipart(raw_msg, copy=False)
558 558 # now, update our data structures
559 559 msg_id = parent['msg_id']
560 560 self.blacklist.pop(msg_id, None)
561 561 self.pending[engine].pop(msg_id)
562 562 if success:
563 563 self.completed[engine].add(msg_id)
564 564 self.all_completed.add(msg_id)
565 565 else:
566 566 self.failed[engine].add(msg_id)
567 567 self.all_failed.add(msg_id)
568 568 self.all_done.add(msg_id)
569 569 self.destinations[msg_id] = engine
570 570
571 571 self.update_graph(msg_id, success)
572 572
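The raw_msg[:2] assignment above rewrites the ROUTER (XREP) routing envelope: the reply arrives prefixed [engine, client, ...], and relaying it to the client requires the client identity first so the client-facing socket can route it. Schematically:

    raw_msg = ['engine-id', 'client-id', 'header', 'content']  # as received
    engine, client = raw_msg[0], raw_msg[1]
    raw_msg[:2] = [client, engine]   # now routable back to the client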
573 573 def handle_unmet_dependency(self, idents, parent):
574 574 """handle an unmet dependency"""
575 575 engine = idents[0]
576 576 msg_id = parent['msg_id']
577 577
578 578 if msg_id not in self.blacklist:
579 579 self.blacklist[msg_id] = set()
580 580 self.blacklist[msg_id].add(engine)
581 581
582 582 args = self.pending[engine].pop(msg_id)
583 583 raw,targets,after,follow,timeout = args
584 584
585 585 if self.blacklist[msg_id] == targets:
586 586 self.depending[msg_id] = args
587 587 self.fail_unreachable(msg_id)
588 588 elif not self.maybe_run(msg_id, *args):
589 589 # resubmit failed
590 590 if msg_id not in self.all_failed:
591 591 # put it back in our dependency tree
592 592 self.save_unmet(msg_id, *args)
593 593
594 594 if self.hwm:
595 595 try:
596 596 idx = self.targets.index(engine)
597 597 except ValueError:
598 598 pass # skip load-update for dead engines
599 599 else:
600 600 if self.loads[idx] == self.hwm-1:
601 601 self.update_graph(None)
602 602
603 603
604 604
605 605 def update_graph(self, dep_id=None, success=True):
606 606 """dep_id just finished. Update our dependency
607 607 graph and submit any jobs that just became runnable.
608 608
609 609 Called with dep_id=None to update the entire graph for hwm, but without finishing
610 610 a task.
611 611 """
612 612 # print ("\n\n***********")
613 613 # pprint (dep_id)
614 614 # pprint (self.graph)
615 615 # pprint (self.depending)
616 616 # pprint (self.all_completed)
617 617 # pprint (self.all_failed)
618 618 # print ("\n\n***********\n\n")
619 619 # update any jobs that depended on the dependency
620 620 jobs = self.graph.pop(dep_id, [])
621 621
622 622 # recheck *all* jobs if
623 623 # a) we have HWM and an engine just dropped below its high-water mark
624 624 # or b) dep_id was given as None
625 625 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
626 626 jobs = self.depending.keys()
627 627
628 628 for msg_id in jobs:
629 629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
630 630
631 631 if after.unreachable(self.all_completed, self.all_failed)\
632 632 or follow.unreachable(self.all_completed, self.all_failed):
633 633 self.fail_unreachable(msg_id)
634 634
635 635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
636 636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
637 637
638 638 self.depending.pop(msg_id)
639 639 for mid in follow.union(after):
640 640 if mid in self.graph:
641 641 self.graph[mid].remove(msg_id)
642 642
643 643 #----------------------------------------------------------------------
644 644 # methods to be overridden by subclasses
645 645 #----------------------------------------------------------------------
646 646
647 647 def add_job(self, idx):
648 648 """Called after self.targets[idx] just got the job with header.
649 649 Override with subclasses. The default ordering is simple LRU.
650 650 The default loads are the number of outstanding jobs."""
651 651 self.loads[idx] += 1
652 652 for lis in (self.targets, self.loads):
653 653 lis.append(lis.pop(idx))
654 654
655 655
656 656 def finish_job(self, idx):
657 657 """Called after self.targets[idx] just finished a job.
658 658 Override with subclasses."""
659 659 self.loads[idx] -= 1
660 660
661 661
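The append/pop rotation in add_job keeps targets and loads in LRU order: the engine that just received work moves to the back, so among equally loaded engines the front one is the least recently used. With hypothetical values:

    targets = ['e0', 'e1', 'e2']
    loads = [0, 0, 0]
    idx = 1                          # e1 just got a job
    loads[idx] += 1
    for lis in (targets, loads):
        lis.append(lis.pop(idx))
    print(targets)                   # ['e0', 'e2', 'e1']
    print(loads)                     # [0, 0, 1]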
662 662
663 663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
664 664 logname='root', log_url=None, loglevel=logging.DEBUG,
665 665 identity=b'task', in_thread=False):
666 666
667 667 ZMQStream = zmqstream.ZMQStream
668 668
669 669 if config:
670 670 # unwrap dict back into Config
671 671 config = Config(config)
672 672
673 673 if in_thread:
674 674 # use instance() to get the same Context/Loop as our parent
675 675 ctx = zmq.Context.instance()
676 676 loop = ioloop.IOLoop.instance()
677 677 else:
678 678 # in a process, don't use instance()
679 679 # for safety with multiprocessing
680 680 ctx = zmq.Context()
681 681 loop = ioloop.IOLoop()
682 682 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
683 683 ins.setsockopt(zmq.IDENTITY, identity)
684 684 ins.bind(in_addr)
685 685
686 686 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
687 687 outs.setsockopt(zmq.IDENTITY, identity)
688 688 outs.bind(out_addr)
689 689 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
690 690 mons.connect(mon_addr)
691 691 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
692 692 nots.setsockopt(zmq.SUBSCRIBE, b'')
693 693 nots.connect(not_addr)
694 694
695 695 # setup logging.
696 696 if in_thread:
697 697 log = Application.instance().log
698 698 else:
699 699 if log_url:
700 700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
701 701 else:
702 702 log = local_logger(logname, loglevel)
703 703
704 704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
705 705 mon_stream=mons, notifier_stream=nots,
706 706 loop=loop, log=log,
707 707 config=config)
708 708 scheduler.start()
709 709 if not in_thread:
710 710 try:
711 711 loop.start()
712 712 except KeyboardInterrupt:
713 713 print ("interrupted, exiting...", file=sys.__stderr__)
714 714
@@ -1,174 +1,174 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 It handles registration, etc., and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 # internal
26 26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 27 # from IPython.utils.localinterfaces import LOCALHOST
28 28
29 29 from IPython.parallel.controller.heartmonitor import Heart
30 30 from IPython.parallel.factory import RegistrationFactory
31 31 from IPython.parallel.util import disambiguate_url, asbytes
32 32
33 33 from IPython.zmq.session import Message
34 34
35 35 from .streamkernel import Kernel
36 36
37 37 class EngineFactory(RegistrationFactory):
38 38 """IPython engine"""
39 39
40 40 # configurables:
41 41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 42 help="""The OutStream for handling stdout/err.
43 43 Typically 'IPython.zmq.iostream.OutStream'""")
44 44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 45 help="""The class for handling displayhook.
46 46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 47 location=Unicode(config=True,
48 48 help="""The location (an IP address) of the controller. This is
49 49 used for disambiguating URLs, to determine whether to
50 50 connect via loopback or via the public address.""")
51 51 timeout=CFloat(2,config=True,
52 52 help="""The time (in seconds) to wait for the Controller to respond
53 53 to registration requests before giving up.""")
54 54
55 55 # not configurable:
56 56 user_ns=Dict()
57 57 id=Int(allow_none=True)
58 58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 59 kernel=Instance(Kernel)
60 60
61 61 bident = CBytes()
62 62 ident = Unicode()
63 63 def _ident_changed(self, name, old, new):
64 64 self.bident = asbytes(new)
65 65
66 66
67 67 def __init__(self, **kwargs):
68 68 super(EngineFactory, self).__init__(**kwargs)
69 69 self.ident = self.session.session
70 70 ctx = self.context
71 71
72 72 reg = ctx.socket(zmq.XREQ)
73 73 reg.setsockopt(zmq.IDENTITY, self.bident)
74 74 reg.connect(self.url)
75 75 self.registrar = zmqstream.ZMQStream(reg, self.loop)
76 76
77 77 def register(self):
78 78 """send the registration_request"""
79 79
80 80 self.log.info("Registering with controller at %s"%self.url)
81 81 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
82 82 self.registrar.on_recv(self.complete_registration)
83 83 # print (self.session.key)
84 84 self.session.send(self.registrar, "registration_request",content=content)
85 85
86 86 def complete_registration(self, msg):
87 87 # print msg
88 88 self._abort_dc.stop()
89 89 ctx = self.context
90 90 loop = self.loop
91 91 identity = self.bident
92 92 idents,msg = self.session.feed_identities(msg)
93 msg = Message(self.session.unpack_message(msg))
93 msg = Message(self.session.unserialize(msg))
94 94
95 95 if msg.content.status == 'ok':
96 96 self.id = int(msg.content.id)
97 97
98 98 # create Shell Streams (MUX, Task, etc.):
99 99 queue_addr = msg.content.mux
100 100 shell_addrs = [ str(queue_addr) ]
101 101 task_addr = msg.content.task
102 102 if task_addr:
103 103 shell_addrs.append(str(task_addr))
104 104
105 105 # Uncomment this to go back to two-socket model
106 106 # shell_streams = []
107 107 # for addr in shell_addrs:
108 108 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
109 109 # stream.setsockopt(zmq.IDENTITY, identity)
110 110 # stream.connect(disambiguate_url(addr, self.location))
111 111 # shell_streams.append(stream)
112 112
113 113 # Now use only one shell stream for mux and tasks
114 114 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 115 stream.setsockopt(zmq.IDENTITY, identity)
116 116 shell_streams = [stream]
117 117 for addr in shell_addrs:
118 118 stream.connect(disambiguate_url(addr, self.location))
119 119 # end single stream-socket
120 120
121 121 # control stream:
122 122 control_addr = str(msg.content.control)
123 123 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
124 124 control_stream.setsockopt(zmq.IDENTITY, identity)
125 125 control_stream.connect(disambiguate_url(control_addr, self.location))
126 126
127 127 # create iopub stream:
128 128 iopub_addr = msg.content.iopub
129 129 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
130 130 iopub_stream.setsockopt(zmq.IDENTITY, identity)
131 131 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
132 132
133 133 # launch heartbeat
134 134 hb_addrs = msg.content.heartbeat
135 135 # print (hb_addrs)
136 136
137 137 # # Redirect input streams and set a display hook.
138 138 if self.out_stream_factory:
139 139 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
140 140 sys.stdout.topic = 'engine.%i.stdout'%self.id
141 141 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
142 142 sys.stderr.topic = 'engine.%i.stderr'%self.id
143 143 if self.display_hook_factory:
144 144 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
145 145 sys.displayhook.topic = 'engine.%i.pyout'%self.id
146 146
147 147 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
148 148 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
149 149 loop=loop, user_ns = self.user_ns, log=self.log)
150 150 self.kernel.start()
151 151 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
152 152 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 153 heart.start()
154 154
155 155
156 156 else:
157 157 self.log.fatal("Registration Failed: %s"%msg)
158 158 raise Exception("Registration Failed: %s"%msg)
159 159
160 160 self.log.info("Completed registration with id %i"%self.id)
161 161
162 162
163 163 def abort(self):
164 164 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
165 165 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
166 166 time.sleep(1)
167 167 sys.exit(255)
168 168
169 169 def start(self):
170 170 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
171 171 dc.start()
172 172 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
173 173 self._abort_dc.start()
174 174
@@ -1,230 +1,230 b''
1 1 """KernelStarter class that intercepts Control Queue messages, and handles process management.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 from zmq.eventloop import ioloop
15 15
16 16 from IPython.zmq.session import Session
17 17
18 18 class KernelStarter(object):
19 19 """Object for resetting/killing the Kernel."""
20 20
21 21
22 22 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
23 23 self.session = session
24 24 self.upstream = upstream
25 25 self.downstream = downstream
26 26 self.kernel_args = kernel_args
27 27 self.kernel_kwargs = kernel_kwargs
28 28 self.handlers = {}
29 29 for method in 'shutdown_request shutdown_reply'.split():
30 30 self.handlers[method] = getattr(self, method)
31 31
32 32 def start(self):
33 33 self.upstream.on_recv(self.dispatch_request)
34 34 self.downstream.on_recv(self.dispatch_reply)
35 35
36 36 #--------------------------------------------------------------------------
37 37 # Dispatch methods
38 38 #--------------------------------------------------------------------------
39 39
40 40 def dispatch_request(self, raw_msg):
41 41 idents, msg = self.session.feed_identities(raw_msg)
42 42 try:
43 msg = self.session.unpack_message(msg, content=False)
43 msg = self.session.unserialize(msg, content=False)
44 44 except:
45 45 print ("bad msg: %s"%msg)
46 46
47 msgtype = msg['msg_type']
47 msgtype = msg['header']['msg_type']
48 48 handler = self.handlers.get(msgtype, None)
49 49 if handler is None:
50 50 self.downstream.send_multipart(raw_msg, copy=False)
51 51 else:
52 52 handler(msg)
53 53
54 54 def dispatch_reply(self, raw_msg):
55 55 idents, msg = self.session.feed_identities(raw_msg)
56 56 try:
57 msg = self.session.unpack_message(msg, content=False)
57 msg = self.session.unserialize(msg, content=False)
58 58 except:
59 59 print ("bad msg: %s"%msg)
60 60
61 msgtype = msg['msg_type']
61 msgtype = msg['header']['msg_type']
62 62 handler = self.handlers.get(msgtype, None)
63 63 if handler is None:
64 64 self.upstream.send_multipart(raw_msg, copy=False)
65 65 else:
66 66 handler(msg)
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Handlers
70 70 #--------------------------------------------------------------------------
71 71
72 72 def shutdown_request(self, msg):
73 73 """Relay a shutdown request to the kernel process."""
74 74 self.downstream.send_multipart(msg)
75 75
76 76 #--------------------------------------------------------------------------
77 77 # Kernel process management methods, from KernelManager:
78 78 #--------------------------------------------------------------------------
79 79
80 80 def _check_local(addr):
81 81 if isinstance(addr, tuple):
82 82 addr = addr[0]
83 83 return addr in LOCAL_IPS
84 84
85 85 def start_kernel(self, **kw):
86 86 """Starts a kernel process and configures the manager to use it.
87 87
88 88 If random ports (port=0) are being used, this method must be called
89 89 before the channels are created.
90 90
91 91 Parameters
92 92 ----------
93 93 ipython : bool, optional (default True)
94 94 Whether to use an IPython kernel instead of a plain Python kernel.
95 95 """
96 96 self.kernel = Process(target=make_kernel, args=self.kernel_args,
97 97 kwargs=self.kernel_kwargs)
98 98
99 99 def shutdown_kernel(self, restart=False):
100 100 """ Attempts to stop the kernel process cleanly. If the kernel
101 101 cannot be stopped, it is killed, if possible.
102 102 """
103 103 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
104 104 if sys.platform == 'win32':
105 105 self.kill_kernel()
106 106 return
107 107
108 108 # Don't send any additional kernel kill messages immediately, to give
109 109 # the kernel a chance to properly execute shutdown actions. Wait for at
110 110 # most 1s, checking every 0.1s.
111 111 self.xreq_channel.shutdown(restart=restart)
112 112 for i in range(10):
113 113 if self.is_alive:
114 114 time.sleep(0.1)
115 115 else:
116 116 break
117 117 else:
118 118 # OK, we've waited long enough.
119 119 if self.has_kernel:
120 120 self.kill_kernel()
121 121
122 122 def restart_kernel(self, now=False):
123 123 """Restarts a kernel with the same arguments that were used to launch
124 124 it. If the old kernel was launched with random ports, the same ports
125 125 will be used for the new kernel.
126 126
127 127 Parameters
128 128 ----------
129 129 now : bool, optional
130 130 If True, the kernel is forcefully restarted *immediately*, without
131 131 having a chance to do any cleanup action. Otherwise the kernel is
132 132 given 1s to clean up before a forceful restart is issued.
133 133
134 134 In all cases the kernel is restarted, the only difference is whether
135 135 it is given a chance to perform a clean shutdown or not.
136 136 """
137 137 if self._launch_args is None:
138 138 raise RuntimeError("Cannot restart the kernel. "
139 139 "No previous call to 'start_kernel'.")
140 140 else:
141 141 if self.has_kernel:
142 142 if now:
143 143 self.kill_kernel()
144 144 else:
145 145 self.shutdown_kernel(restart=True)
146 146 self.start_kernel(**self._launch_args)
147 147
148 148 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
149 149 # unless there is some delay here.
150 150 if sys.platform == 'win32':
151 151 time.sleep(0.2)
152 152
153 153 @property
154 154 def has_kernel(self):
155 155 """Returns whether a kernel process has been specified for the kernel
156 156 manager.
157 157 """
158 158 return self.kernel is not None
159 159
160 160 def kill_kernel(self):
161 161 """ Kill the running kernel. """
162 162 if self.has_kernel:
163 163 # Pause the heart beat channel if it exists.
164 164 if self._hb_channel is not None:
165 165 self._hb_channel.pause()
166 166
167 167 # Attempt to kill the kernel.
168 168 try:
169 169 self.kernel.kill()
170 170 except OSError, e:
171 171 # In Windows, we will get an Access Denied error if the process
172 172 # has already terminated. Ignore it.
173 173 if not (sys.platform == 'win32' and e.winerror == 5):
174 174 raise
175 175 self.kernel = None
176 176 else:
177 177 raise RuntimeError("Cannot kill kernel. No kernel is running!")
178 178
179 179 def interrupt_kernel(self):
180 180 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
181 181 well supported on all platforms.
182 182 """
183 183 if self.has_kernel:
184 184 if sys.platform == 'win32':
185 185 from parentpoller import ParentPollerWindows as Poller
186 186 Poller.send_interrupt(self.kernel.win32_interrupt_event)
187 187 else:
188 188 self.kernel.send_signal(signal.SIGINT)
189 189 else:
190 190 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
191 191
192 192 def signal_kernel(self, signum):
193 193 """ Sends a signal to the kernel. Note that since only SIGTERM is
194 194 supported on Windows, this function is only useful on Unix systems.
195 195 """
196 196 if self.has_kernel:
197 197 self.kernel.send_signal(signum)
198 198 else:
199 199 raise RuntimeError("Cannot signal kernel. No kernel is running!")
200 200
201 201 @property
202 202 def is_alive(self):
203 203 """Is the kernel process still running?"""
204 204 # FIXME: not using a heartbeat means this method is broken for any
205 205 # remote kernel, it's only capable of handling local kernels.
206 206 if self.has_kernel:
207 207 if self.kernel.poll() is None:
208 208 return True
209 209 else:
210 210 return False
211 211 else:
212 212 # We didn't start the kernel with this KernelManager so we don't
213 213 # know if it is running. We should use a heartbeat for this case.
214 214 return True
215 215
216 216
217 217 def make_starter(up_addr, down_addr, *args, **kwargs):
218 218 """entry point function for launching a kernelstarter in a subprocess"""
219 219 loop = ioloop.IOLoop.instance()
220 220 ctx = zmq.Context()
221 221 session = Session()
222 222 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
223 223 upstream.connect(up_addr)
224 224 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
225 225 downstream.connect(down_addr)
226 226
227 227 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
228 228 starter.start()
229 229 loop.start()
230 No newline at end of file
230
@@ -1,438 +1,440 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 * Brian Granger
9 9 * Fernando Perez
10 10 * Evan Patterson
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2010-2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 # Standard library imports.
24 24 from __future__ import print_function
25 25
26 26 import sys
27 27 import time
28 28
29 29 from code import CommandCompiler
30 30 from datetime import datetime
31 31 from pprint import pprint
32 32
33 33 # System library imports.
34 34 import zmq
35 35 from zmq.eventloop import ioloop, zmqstream
36 36
37 37 # Local imports.
38 38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 39 from IPython.zmq.completer import KernelCompleter
40 40
41 41 from IPython.parallel.error import wrap_exception
42 42 from IPython.parallel.factory import SessionFactory
43 43 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44 44
45 45 def printer(*args):
46 46 pprint(args, stream=sys.__stdout__)
47 47
48 48
49 49 class _Passer(zmqstream.ZMQStream):
50 50 """Empty class that implements `send()` that does nothing.
51 51
52 52 Subclass ZMQStream for Session typechecking
53 53
54 54 """
55 55 def __init__(self, *args, **kwargs):
56 56 pass
57 57
58 58 def send(self, *args, **kwargs):
59 59 pass
60 60 send_multipart = send
61 61
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Main kernel class
65 65 #-----------------------------------------------------------------------------
66 66
67 67 class Kernel(SessionFactory):
68 68
69 69 #---------------------------------------------------------------------------
70 70 # Kernel interface
71 71 #---------------------------------------------------------------------------
72 72
73 73 # kwargs:
74 74 exec_lines = List(Unicode, config=True,
75 75 help="List of lines to execute")
76 76
77 77 # identities:
78 78 int_id = Int(-1)
79 79 bident = CBytes()
80 80 ident = Unicode()
81 81 def _ident_changed(self, name, old, new):
82 82 self.bident = asbytes(new)
83 83
84 84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85 85
86 86 control_stream = Instance(zmqstream.ZMQStream)
87 87 task_stream = Instance(zmqstream.ZMQStream)
88 88 iopub_stream = Instance(zmqstream.ZMQStream)
89 89 client = Instance('IPython.parallel.Client')
90 90
91 91 # internals
92 92 shell_streams = List()
93 93 compiler = Instance(CommandCompiler, (), {})
94 94 completer = Instance(KernelCompleter)
95 95
96 96 aborted = Set()
97 97 shell_handlers = Dict()
98 98 control_handlers = Dict()
99 99
100 100 def _set_prefix(self):
101 101 self.prefix = "engine.%s"%self.int_id
102 102
103 103 def _connect_completer(self):
104 104 self.completer = KernelCompleter(self.user_ns)
105 105
106 106 def __init__(self, **kwargs):
107 107 super(Kernel, self).__init__(**kwargs)
108 108 self._set_prefix()
109 109 self._connect_completer()
110 110
111 111 self.on_trait_change(self._set_prefix, 'id')
112 112 self.on_trait_change(self._connect_completer, 'user_ns')
113 113
114 114 # Build dict of handlers for message types
115 115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 116 'clear_request']:
117 117 self.shell_handlers[msg_type] = getattr(self, msg_type)
118 118
119 119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 120 self.control_handlers[msg_type] = getattr(self, msg_type)
121 121
122 122 self._initial_exec_lines()
123 123
124 124 def _wrap_exception(self, method=None):
125 125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 126 content=wrap_exception(e_info)
127 127 return content
128 128
129 129 def _initial_exec_lines(self):
130 130 s = _Passer()
131 131 content = dict(silent=True, user_variables=[], user_expressions=[])
132 132 for line in self.exec_lines:
133 133 self.log.debug("executing initialization: %s"%line)
134 134 content.update({'code':line})
135 135 msg = self.session.msg('execute_request', content)
136 136 self.execute_request(s, [], msg)
137 137
138 138
139 139 #-------------------- control handlers -----------------------------
140 140 def abort_queues(self):
141 141 for stream in self.shell_streams:
142 142 if stream:
143 143 self.abort_queue(stream)
144 144
145 145 def abort_queue(self, stream):
146 146 while True:
147 147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 148 if msg is None:
149 149 return
150 150
151 151 self.log.info("Aborting:")
152 152 self.log.info(str(msg))
153 msg_type = msg['msg_type']
153 msg_type = msg['header']['msg_type']
154 154 reply_type = msg_type.split('_')[0] + '_reply'
155 155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 156 # self.reply_socket.send(ident,zmq.SNDMORE)
157 157 # self.reply_socket.send_json(reply_msg)
158 158 reply_msg = self.session.send(stream, reply_type,
159 159 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 160 self.log.debug(str(reply_msg))
161 161 # We need to wait a bit for requests to come in. This can probably
162 162 # be set shorter for true asynchronous clients.
163 163 time.sleep(0.05)
164 164
165 165 def abort_request(self, stream, ident, parent):
166 166 """abort specific msgs by id"""
167 167 msg_ids = parent['content'].get('msg_ids', None)
168 168 if isinstance(msg_ids, basestring):
169 169 msg_ids = [msg_ids]
170 170 if not msg_ids:
171 171 self.abort_queues()
172 172 for mid in msg_ids:
173 173 self.aborted.add(str(mid))
174 174
175 175 content = dict(status='ok')
176 176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 177 parent=parent, ident=ident)
178 178 self.log.debug(str(reply_msg))
179 179
180 180 def shutdown_request(self, stream, ident, parent):
181 181 """kill ourself. This should really be handled in an external process"""
182 182 try:
183 183 self.abort_queues()
184 184 except:
185 185 content = self._wrap_exception('shutdown')
186 186 else:
187 187 content = dict(parent['content'])
188 188 content['status'] = 'ok'
189 189 msg = self.session.send(stream, 'shutdown_reply',
190 190 content=content, parent=parent, ident=ident)
191 191 self.log.debug(str(msg))
192 192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 193 dc.start()
194 194
195 195 def dispatch_control(self, msg):
196 196 idents,msg = self.session.feed_identities(msg, copy=False)
197 197 try:
198 msg = self.session.unpack_message(msg, content=True, copy=False)
198 msg = self.session.unserialize(msg, content=True, copy=False)
199 199 except:
200 200 self.log.error("Invalid Message", exc_info=True)
201 201 return
202 202 else:
203 203 self.log.debug("Control received, %s", msg)
204 204
205 205 header = msg['header']
206 206 msg_id = header['msg_id']
207
208 handler = self.control_handlers.get(msg['msg_type'], None)
207 msg_type = header['msg_type']
208
209 handler = self.control_handlers.get(msg_type, None)
209 210 if handler is None:
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg_type)
211 212 else:
212 213 handler(self.control_stream, idents, msg)
213 214
214 215
215 216 #-------------------- queue helpers ------------------------------
216 217
217 218 def check_dependencies(self, dependencies):
218 219 if not dependencies:
219 220 return True
220 221 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 222 anyorall = dependencies[0]
222 223 dependencies = dependencies[1]
223 224 else:
224 225 anyorall = 'all'
225 226 results = self.client.get_results(dependencies,status_only=True)
226 227 if results['status'] != 'ok':
227 228 return False
228 229
229 230 if anyorall == 'any':
230 231 if not results['completed']:
231 232 return False
232 233 else:
233 234 if results['pending']:
234 235 return False
235 236
236 237 return True
237 238
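check_dependencies accepts either a plain list of msg_ids, meaning all of them must be done, or a two-element ['any'|'all', msg_ids] pair. Illustrative shapes (ids invented):

    deps_all = ['id-1', 'id-2']             # implicit 'all': every id must be done
    deps_any = ['any', ['id-1', 'id-2']]    # one completed id suffices
    deps_all2 = ['all', ['id-1', 'id-2']]   # explicit form of the first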
238 239 def check_aborted(self, msg_id):
239 240 return msg_id in self.aborted
240 241
241 242 #-------------------- queue handlers -----------------------------
242 243
243 244 def clear_request(self, stream, idents, parent):
244 245 """Clear our namespace."""
245 246 self.user_ns = {}
246 247 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
247 248 content = dict(status='ok'))
248 249 self._initial_exec_lines()
249 250
250 251 def execute_request(self, stream, ident, parent):
251 252 self.log.debug('execute request %s'%parent)
252 253 try:
253 254 code = parent[u'content'][u'code']
254 255 except:
255 256 self.log.error("Got bad msg: %s"%parent, exc_info=True)
256 257 return
257 258 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
258 259 ident=asbytes('%s.pyin'%self.prefix))
259 260 started = datetime.now()
260 261 try:
261 262 comp_code = self.compiler(code, '<zmq-kernel>')
262 263 # allow for not overriding displayhook
263 264 if hasattr(sys.displayhook, 'set_parent'):
264 265 sys.displayhook.set_parent(parent)
265 266 sys.stdout.set_parent(parent)
266 267 sys.stderr.set_parent(parent)
267 268 exec comp_code in self.user_ns, self.user_ns
268 269 except:
269 270 exc_content = self._wrap_exception('execute')
270 271 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
271 272 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
272 273 ident=asbytes('%s.pyerr'%self.prefix))
273 274 reply_content = exc_content
274 275 else:
275 276 reply_content = {'status' : 'ok'}
276 277
277 278 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
278 279 ident=ident, subheader = dict(started=started))
279 280 self.log.debug(str(reply_msg))
280 281 if reply_msg['content']['status'] == u'error':
281 282 self.abort_queues()
282 283
283 284 def complete_request(self, stream, ident, parent):
284 285 matches = {'matches' : self.complete(parent),
285 286 'status' : 'ok'}
286 287 completion_msg = self.session.send(stream, 'complete_reply',
287 288 matches, parent, ident)
288 289 # print >> sys.__stdout__, completion_msg
289 290
290 291 def complete(self, msg):
291 292 return self.completer.complete(msg.content.line, msg.content.text)
292 293
293 294 def apply_request(self, stream, ident, parent):
294 295 # flush previous reply, so this request won't block it
295 296 stream.flush(zmq.POLLOUT)
296 297 try:
297 298 content = parent[u'content']
298 299 bufs = parent[u'buffers']
299 300 msg_id = parent['header']['msg_id']
300 301 # bound = parent['header'].get('bound', False)
301 302 except:
302 303 self.log.error("Got bad msg: %s"%parent, exc_info=True)
303 304 return
304 305 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
305 306 # self.iopub_stream.send(pyin_msg)
306 307 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
307 308 sub = {'dependencies_met' : True, 'engine' : self.ident,
308 309 'started': datetime.now()}
309 310 try:
310 311 # allow for not overriding displayhook
311 312 if hasattr(sys.displayhook, 'set_parent'):
312 313 sys.displayhook.set_parent(parent)
313 314 sys.stdout.set_parent(parent)
314 315 sys.stderr.set_parent(parent)
315 316 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
316 317 working = self.user_ns
317 318 # suffix =
318 319 prefix = "_"+str(msg_id).replace("-","")+"_"
319 320
320 321 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
321 322 # if bound:
322 323 # bound_ns = Namespace(working)
323 324 # args = [bound_ns]+list(args)
324 325
325 326 fname = getattr(f, '__name__', 'f')
326 327
327 328 fname = prefix+"f"
328 329 argname = prefix+"args"
329 330 kwargname = prefix+"kwargs"
330 331 resultname = prefix+"result"
331 332
332 333 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
333 334 # print ns
334 335 working.update(ns)
335 336 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
336 337 try:
337 338 exec code in working,working
338 339 result = working.get(resultname)
339 340 finally:
340 341 for key in ns.iterkeys():
341 342 working.pop(key)
342 343 # if bound:
343 344 # working.update(bound_ns)
344 345
345 346 packed_result,buf = serialize_object(result)
346 347 result_buf = [packed_result]+buf
347 348 except:
348 349 exc_content = self._wrap_exception('apply')
349 350 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
350 351 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
351 352 ident=asbytes('%s.pyerr'%self.prefix))
352 353 reply_content = exc_content
353 354 result_buf = []
354 355
355 356 if exc_content['ename'] == 'UnmetDependency':
356 357 sub['dependencies_met'] = False
357 358 else:
358 359 reply_content = {'status' : 'ok'}
359 360
360 361 # put 'ok'/'error' status in header, for scheduler introspection:
361 362 sub['status'] = reply_content['status']
362 363
363 364 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
364 365 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
365 366
366 367 # flush i/o
367 368 # should this be before reply_msg is sent, like in the single-kernel code,
368 369 # or should nothing get in the way of real results?
369 370 sys.stdout.flush()
370 371 sys.stderr.flush()
371 372
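The prefix mangling above stages the deserialized function and arguments in the user namespace under collision-resistant names, then invokes them with a generated assignment statement. A sketch with a hypothetical msg_id:

    msg_id = 'de305d54-75b4-431b-adb2-eb6b9e546014'   # hypothetical
    prefix = "_" + str(msg_id).replace("-", "") + "_"
    fname, argname, kwargname = prefix+"f", prefix+"args", prefix+"kwargs"
    code = "%s=%s(*%s,**%s)" % (prefix+"result", fname, argname, kwargname)
    # code is an assignment like "_de305..._result=_..._f(*_..._args,**_..._kwargs)"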
372 373 def dispatch_queue(self, stream, msg):
373 374 self.control_stream.flush()
374 375 idents,msg = self.session.feed_identities(msg, copy=False)
375 376 try:
376 msg = self.session.unpack_message(msg, content=True, copy=False)
377 msg = self.session.unserialize(msg, content=True, copy=False)
377 378 except:
378 379 self.log.error("Invalid Message", exc_info=True)
379 380 return
380 381 else:
381 382 self.log.debug("Message received, %s", msg)
382 383
383 384
384 385 header = msg['header']
385 386 msg_id = header['msg_id']
387 msg_type = msg['header']['msg_type']
386 388 if self.check_aborted(msg_id):
387 389 self.aborted.remove(msg_id)
388 390 # is it safe to assume a msg_id will not be resubmitted?
389 reply_type = msg['msg_type'].split('_')[0] + '_reply'
391 reply_type = msg_type.split('_')[0] + '_reply'
390 392 status = {'status' : 'aborted'}
391 393 reply_msg = self.session.send(stream, reply_type, subheader=status,
392 394 content=status, parent=msg, ident=idents)
393 395 return
394 handler = self.shell_handlers.get(msg['msg_type'], None)
396 handler = self.shell_handlers.get(msg_type, None)
395 397 if handler is None:
396 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
398 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg_type)
397 399 else:
398 400 handler(stream, idents, msg)
399 401
400 402 def start(self):
401 403 #### stream mode:
402 404 if self.control_stream:
403 405 self.control_stream.on_recv(self.dispatch_control, copy=False)
404 406 self.control_stream.on_err(printer)
405 407
406 408 def make_dispatcher(stream):
407 409 def dispatcher(msg):
408 410 return self.dispatch_queue(stream, msg)
409 411 return dispatcher
410 412
411 413 for s in self.shell_streams:
412 414 s.on_recv(make_dispatcher(s), copy=False)
413 415 s.on_err(printer)
414 416
415 417 if self.iopub_stream:
416 418 self.iopub_stream.on_err(printer)
417 419
418 420 #### while True mode:
419 421 # while True:
420 422 # idle = True
421 423 # try:
422 424 # msg = self.shell_stream.socket.recv_multipart(
423 425 # zmq.NOBLOCK, copy=False)
424 426 # except zmq.ZMQError, e:
425 427 # if e.errno != zmq.EAGAIN:
426 428 # raise e
427 429 # else:
428 430 # idle=False
429 431 # self.dispatch_queue(self.shell_stream, msg)
430 432 #
431 433 # if not self.task_stream.empty():
432 434 # idle=False
433 435 # msg = self.task_stream.recv_multipart()
434 436 # self.dispatch_queue(self.task_stream, msg)
435 437 # if idle:
436 438 # # don't busywait
437 439 # time.sleep(1e-3)
438 440
@@ -1,179 +1,180 b''
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import tempfile
22 22 import time
23 23
24 24 from datetime import datetime, timedelta
25 25 from unittest import TestCase
26 26
27 27 from nose import SkipTest
28 28
29 29 from IPython.parallel import error
30 30 from IPython.parallel.controller.dictdb import DictDB
31 31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 32 from IPython.parallel.controller.hub import init_record, empty_record
33 33
34 34 from IPython.zmq.session import Session
35 35
36 36
37 37 #-------------------------------------------------------------------------------
38 38 # TestCases
39 39 #-------------------------------------------------------------------------------
40 40
41 41 class TestDictBackend(TestCase):
42 42 def setUp(self):
43 43 self.session = Session()
44 44 self.db = self.create_db()
45 45 self.load_records(16)
46 46
47 47 def create_db(self):
48 48 return DictDB()
49 49
50 50 def load_records(self, n=1):
51 51 """load n records for testing"""
52 52 #sleep 1/10 s, to ensure timestamp is different to previous calls
53 53 time.sleep(0.1)
54 54 msg_ids = []
55 55 for i in range(n):
56 56 msg = self.session.msg('apply_request', content=dict(a=5))
57 57 msg['buffers'] = []
58 58 rec = init_record(msg)
59 msg_ids.append(msg['msg_id'])
60 self.db.add_record(msg['msg_id'], rec)
59 msg_id = msg['header']['msg_id']
60 msg_ids.append(msg_id)
61 self.db.add_record(msg_id, rec)
61 62 return msg_ids
62 63
63 64 def test_add_record(self):
64 65 before = self.db.get_history()
65 66 self.load_records(5)
66 67 after = self.db.get_history()
67 68 self.assertEquals(len(after), len(before)+5)
68 69 self.assertEquals(after[:-5],before)
69 70
70 71 def test_drop_record(self):
71 72 msg_id = self.load_records()[-1]
72 73 rec = self.db.get_record(msg_id)
73 74 self.db.drop_record(msg_id)
74 75 self.assertRaises(KeyError,self.db.get_record, msg_id)
75 76
76 77 def _round_to_millisecond(self, dt):
77 78 """necessary because mongodb rounds microseconds"""
78 79 micro = dt.microsecond
79 80 extra = int(str(micro)[-3:])
80 81 return dt - timedelta(microseconds=extra)
81 82
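For example, microsecond=123456 gives extra=456, so the rounded timestamp keeps microsecond=123000, matching MongoDB's millisecond resolution:

    from datetime import datetime, timedelta
    dt = datetime(2011, 1, 1, microsecond=123456)
    extra = int(str(dt.microsecond)[-3:])          # 456
    print(dt - timedelta(microseconds=extra))      # 2011-01-01 00:00:00.123000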
82 83 def test_update_record(self):
83 84 now = self._round_to_millisecond(datetime.now())
84 85 #
85 86 msg_id = self.db.get_history()[-1]
86 87 rec1 = self.db.get_record(msg_id)
87 88 data = {'stdout': 'hello there', 'completed' : now}
88 89 self.db.update_record(msg_id, data)
89 90 rec2 = self.db.get_record(msg_id)
90 91 self.assertEquals(rec2['stdout'], 'hello there')
91 92 self.assertEquals(rec2['completed'], now)
92 93 rec1.update(data)
93 94 self.assertEquals(rec1, rec2)
94 95
95 96 # def test_update_record_bad(self):
96 97 # """test updating nonexistent records"""
97 98 # msg_id = str(uuid.uuid4())
98 99 # data = {'stdout': 'hello there'}
99 100 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
100 101
101 102 def test_find_records_dt(self):
102 103 """test finding records by date"""
103 104 hist = self.db.get_history()
104 105 middle = self.db.get_record(hist[len(hist)//2])
105 106 tic = middle['submitted']
106 107 before = self.db.find_records({'submitted' : {'$lt' : tic}})
107 108 after = self.db.find_records({'submitted' : {'$gte' : tic}})
108 109 self.assertEquals(len(before)+len(after),len(hist))
109 110 for b in before:
110 111 self.assertTrue(b['submitted'] < tic)
111 112 for a in after:
112 113 self.assertTrue(a['submitted'] >= tic)
113 114 same = self.db.find_records({'submitted' : tic})
114 115 for s in same:
115 116 self.assertTrue(s['submitted'] == tic)
116 117
117 118 def test_find_records_keys(self):
118 119 """test extracting subset of record keys"""
119 120 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 121 for rec in found:
121 122 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
122 123
123 124 def test_find_records_msg_id(self):
124 125 """ensure msg_id is always in found records"""
125 126 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
126 127 for rec in found:
127 128 self.assertTrue('msg_id' in rec.keys())
128 129 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
129 130 for rec in found:
130 131 self.assertTrue('msg_id' in rec.keys())
131 132 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
132 133 for rec in found:
133 134 self.assertTrue('msg_id' in rec.keys())
134 135
135 136 def test_find_records_in(self):
136 137 """test finding records with '$in','$nin' operators"""
137 138 hist = self.db.get_history()
138 139 even = hist[::2]
139 140 odd = hist[1::2]
140 141 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
141 142 found = [ r['msg_id'] for r in recs ]
142 143 self.assertEquals(set(even), set(found))
143 144 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
144 145 found = [ r['msg_id'] for r in recs ]
145 146 self.assertEquals(set(odd), set(found))
146 147
147 148 def test_get_history(self):
148 149 msg_ids = self.db.get_history()
149 150 latest = datetime(1984,1,1)
150 151 for msg_id in msg_ids:
151 152 rec = self.db.get_record(msg_id)
152 153 newt = rec['submitted']
153 154 self.assertTrue(newt >= latest)
154 155 latest = newt
155 156 msg_id = self.load_records(1)[-1]
156 157 self.assertEquals(self.db.get_history()[-1],msg_id)
157 158
158 159 def test_datetime(self):
159 160 """get/set timestamps with datetime objects"""
160 161 msg_id = self.db.get_history()[-1]
161 162 rec = self.db.get_record(msg_id)
162 163 self.assertTrue(isinstance(rec['submitted'], datetime))
163 164 self.db.update_record(msg_id, dict(completed=datetime.now()))
164 165 rec = self.db.get_record(msg_id)
165 166 self.assertTrue(isinstance(rec['completed'], datetime))
166 167
167 168 def test_drop_matching(self):
168 169 msg_ids = self.load_records(10)
169 170 query = {'msg_id' : {'$in':msg_ids}}
170 171 self.db.drop_matching_records(query)
171 172 recs = self.db.find_records(query)
172 173 self.assertEquals(len(recs), 0)
173 174
174 175 class TestSQLiteBackend(TestDictBackend):
175 176 def create_db(self):
176 177 return SQLiteDB(location=tempfile.gettempdir())
177 178
178 179 def tearDown(self):
179 180 self.db._db.close()
@@ -1,678 +1,680 b''
1 1 #!/usr/bin/env python
2 2 """A simple interactive kernel that talks to a frontend over 0MQ.
3 3
4 4 Things to do:
5 5
6 6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 7 call set_parent on all the PUB objects with the message about to be executed.
8 8 * Implement random port and security key logic.
9 9 * Implement control messages.
10 10 * Implement event loop and poll version.
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 # Standard library imports.
19 19 import __builtin__
20 20 import atexit
21 21 import sys
22 22 import time
23 23 import traceback
24 24 import logging
25 25 # System library imports.
26 26 import zmq
27 27
28 28 # Local imports.
29 29 from IPython.config.configurable import Configurable
30 30 from IPython.config.application import boolean_flag
31 31 from IPython.core.application import ProfileDir
32 32 from IPython.core.shellapp import (
33 33 InteractiveShellApp, shell_flags, shell_aliases
34 34 )
35 35 from IPython.utils import io
36 36 from IPython.utils.jsonutil import json_clean
37 37 from IPython.lib import pylabtools
38 38 from IPython.utils.traitlets import (
39 39 List, Instance, Float, Dict, Bool, Int, Unicode, CaselessStrEnum
40 40 )
41 41
42 42 from entry_point import base_launch_kernel
43 43 from kernelapp import KernelApp, kernel_flags, kernel_aliases
44 44 from iostream import OutStream
45 45 from session import Session, Message
46 46 from zmqshell import ZMQInteractiveShell
47 47
48 48
49 49
50 50 #-----------------------------------------------------------------------------
51 51 # Main kernel class
52 52 #-----------------------------------------------------------------------------
53 53
54 54 class Kernel(Configurable):
55 55
56 56 #---------------------------------------------------------------------------
57 57 # Kernel interface
58 58 #---------------------------------------------------------------------------
59 59
60 60 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
61 61 session = Instance(Session)
62 62 shell_socket = Instance('zmq.Socket')
63 63 iopub_socket = Instance('zmq.Socket')
64 64 stdin_socket = Instance('zmq.Socket')
65 65 log = Instance(logging.Logger)
66 66
67 67 # Private interface
68 68
69 69 # Time to sleep after flushing the stdout/err buffers in each execute
70 70 # cycle. While this introduces a hard limit on the minimal latency of the
71 71 # execute cycle, it helps prevent output synchronization problems for
72 72 # clients.
73 73 # Units are in seconds. The minimum zmq latency on local host is probably
74 74 # ~150 microseconds, set this to 500us for now. We may need to increase it
75 75 # a little if it's not enough after more interactive testing.
76 76 _execute_sleep = Float(0.0005, config=True)
77 77
78 78 # Frequency of the kernel's event loop.
79 79 # Units are in seconds, kernel subclasses for GUI toolkits may need to
80 80 # adapt to milliseconds.
81 81 _poll_interval = Float(0.05, config=True)
82 82
83 83 # If the shutdown was requested over the network, we leave here the
84 84 # necessary reply message so it can be sent by our registered atexit
85 85 # handler. This ensures that the reply is only sent to clients truly at
86 86 # the end of our shutdown process (which happens after the underlying
87 87 # IPython shell's own shutdown).
88 88 _shutdown_message = None
89 89
90 90 # This is a dict of port number that the kernel is listening on. It is set
91 91 # by record_ports and used by connect_request.
92 92 _recorded_ports = Dict()
93 93
94 94
95 95
96 96 def __init__(self, **kwargs):
97 97 super(Kernel, self).__init__(**kwargs)
98 98
99 99 # Before we even start up the shell, register *first* our exit handlers
100 100 # so they come before the shell's
101 101 atexit.register(self._at_shutdown)
102 102
103 103 # Initialize the InteractiveShell subclass
104 104 self.shell = ZMQInteractiveShell.instance(config=self.config)
105 105 self.shell.displayhook.session = self.session
106 106 self.shell.displayhook.pub_socket = self.iopub_socket
107 107 self.shell.display_pub.session = self.session
108 108 self.shell.display_pub.pub_socket = self.iopub_socket
109 109
110 110 # TMP - hack while developing
111 111 self.shell._reply_content = None
112 112
113 113 # Build dict of handlers for message types
114 114 msg_types = [ 'execute_request', 'complete_request',
115 115 'object_info_request', 'history_request',
116 116 'connect_request', 'shutdown_request']
117 117 self.handlers = {}
118 118 for msg_type in msg_types:
119 119 self.handlers[msg_type] = getattr(self, msg_type)
120 120
121 121 def do_one_iteration(self):
122 122 """Do one iteration of the kernel's evaluation loop.
123 123 """
124 124 ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
125 125 if msg is None:
126 126 return
127
127
128 msg_type = msg['header']['msg_type']
129
128 130 # This assert will raise in versions of zeromq 2.0.7 and lesser.
129 131 # We now require 2.0.8 or above, so we can uncomment for safety.
130 132 # print(ident,msg, file=sys.__stdout__)
131 133 assert ident is not None, "Missing message part."
132 134
133 135 # Print some info about this message and leave a '--->' marker, so it's
134 136 # easier to trace visually the message chain when debugging. Each
135 137 # handler prints its message at the end.
136 self.log.debug('\n*** MESSAGE TYPE:'+str(msg['msg_type'])+'***')
138 self.log.debug('\n*** MESSAGE TYPE:'+str(msg_type)+'***')
137 139 self.log.debug(' Content: '+str(msg['content'])+'\n --->\n ')
138 140
139 141 # Find and call actual handler for message
140 handler = self.handlers.get(msg['msg_type'], None)
142 handler = self.handlers.get(msg_type, None)
141 143 if handler is None:
142 144 self.log.error("UNKNOWN MESSAGE TYPE:" +str(msg))
143 145 else:
144 146 handler(ident, msg)
145 147
146 148 # Check whether we should exit, in case the incoming message set the
147 149 # exit flag on
148 150 if self.shell.exit_now:
149 151 self.log.debug('\nExiting IPython kernel...')
150 152 # We do a normal, clean exit, which allows any actions registered
151 153 # via atexit (such as history saving) to take place.
152 154 sys.exit(0)
153 155
154 156
155 157 def start(self):
156 158 """ Start the kernel main loop.
157 159 """
158 160 poller = zmq.Poller()
159 161 poller.register(self.shell_socket, zmq.POLLIN)
160 162 while True:
161 163 try:
162 164 # scale by extra factor of 10, because there is no
163 165 # reason for this to be anything less than ~ 0.1s
164 166 # since it is a real poller and will respond
165 167 # to events immediately
166 168 poller.poll(10*1000*self._poll_interval)
167 169 self.do_one_iteration()
168 170 except KeyboardInterrupt:
169 171 # Ctrl-C shouldn't crash the kernel
170 172 io.raw_print("KeyboardInterrupt caught in kernel")
171 173
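With the default _poll_interval of 0.05 s, the timeout passed to poll() works out to 10 * 1000 * 0.05 = 500 ms: zmq.Poller.poll takes milliseconds, hence the factor of 1000, with the extra factor of 10 applying the coarsening described in the comment above.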
172 174 def record_ports(self, ports):
173 175 """Record the ports that this kernel is using.
174 176
175 177 The creator of the Kernel instance must call this method if they
176 178 want the :meth:`connect_request` method to return the port numbers.
177 179 """
178 180 self._recorded_ports = ports
179 181
180 182 #---------------------------------------------------------------------------
181 183 # Kernel request handlers
182 184 #---------------------------------------------------------------------------
183 185
184 186 def _publish_pyin(self, code, parent):
185 187 """Publish the code request on the pyin stream."""
186 188
187 189 pyin_msg = self.session.send(self.iopub_socket, u'pyin',{u'code':code}, parent=parent)
188 190
189 191 def execute_request(self, ident, parent):
190 192
191 193 status_msg = self.session.send(self.iopub_socket,
192 194 u'status',
193 195 {u'execution_state':u'busy'},
194 196 parent=parent
195 197 )
196 198
197 199 try:
198 200 content = parent[u'content']
199 201 code = content[u'code']
200 202 silent = content[u'silent']
201 203 except:
202 204 self.log.error("Got bad msg: ")
203 205 self.log.error(str(Message(parent)))
204 206 return
205 207
206 208 shell = self.shell # we'll need this a lot here
207 209
208 210 # Replace raw_input. Note that it is not sufficient to replace
209 211 # raw_input in the user namespace.
210 212 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
211 213 __builtin__.raw_input = raw_input
212 214
213 215 # Set the parent message of the display hook and out streams.
214 216 shell.displayhook.set_parent(parent)
215 217 shell.display_pub.set_parent(parent)
216 218 sys.stdout.set_parent(parent)
217 219 sys.stderr.set_parent(parent)
218 220
219 221 # Re-broadcast our input for the benefit of listening clients, and
220 222 # start computing output
221 223 if not silent:
222 224 self._publish_pyin(code, parent)
223 225
224 226 reply_content = {}
225 227 try:
226 228 if silent:
227 229 # run_code uses 'exec' mode, so no displayhook will fire, and it
228 230 # doesn't call logging or history manipulations. Print
229 231 # statements in that code will obviously still execute.
230 232 shell.run_code(code)
231 233 else:
232 234 # FIXME: the shell calls the exception handler itself.
233 235 shell.run_cell(code)
234 236 except:
235 237 status = u'error'
236 238 # FIXME: this code right now isn't being used yet by default,
237 239 # because the run_cell() call above directly fires off exception
238 240 # reporting. This code, therefore, is only active in the scenario
239 241 # where runlines itself has an unhandled exception. We need to
240 242 # unify this, so that all exception construction comes from a
241 243 # single location in the codebase.
242 244 etype, evalue, tb = sys.exc_info()
243 245 tb_list = traceback.format_exception(etype, evalue, tb)
244 246 reply_content.update(shell._showtraceback(etype, evalue, tb_list))
245 247 else:
246 248 status = u'ok'
247 249
248 250 reply_content[u'status'] = status
249 251
250 252 # Return the execution counter so clients can display prompts
251 253 reply_content['execution_count'] = shell.execution_count - 1
252 254
253 255 # FIXME - fish exception info out of shell, possibly left there by
254 256 # runlines. We'll need to clean up this logic later.
255 257 if shell._reply_content is not None:
256 258 reply_content.update(shell._reply_content)
257 259 # reset after use
258 260 shell._reply_content = None
259 261
260 262 # At this point, we can tell whether the main code execution succeeded
261 263 # or not. If it did, we proceed to evaluate user_variables/expressions
262 264 if reply_content['status'] == 'ok':
263 265 reply_content[u'user_variables'] = \
264 266 shell.user_variables(content[u'user_variables'])
265 267 reply_content[u'user_expressions'] = \
266 268 shell.user_expressions(content[u'user_expressions'])
267 269 else:
268 270 # If there was an error, don't even try to compute variables or
269 271 # expressions
270 272 reply_content[u'user_variables'] = {}
271 273 reply_content[u'user_expressions'] = {}
272 274
273 275 # Payloads should be retrieved regardless of outcome, so we can both
274 276 # recover partial output (that could have been generated early in a
275 277 # block, before an error) and clear the payload system always.
276 278 reply_content[u'payload'] = shell.payload_manager.read_payload()
277 279 # Be aggressive about clearing the payload because we don't want
278 280 # it to sit in memory until the next execute_request comes in.
279 281 shell.payload_manager.clear_payload()
280 282
281 283 # Flush output before sending the reply.
282 284 sys.stdout.flush()
283 285 sys.stderr.flush()
284 286 # FIXME: on rare occasions, the flush doesn't seem to make it to the
285 287 # clients... This seems to mitigate the problem, but we definitely need
286 288 # to better understand what's going on.
287 289 if self._execute_sleep:
288 290 time.sleep(self._execute_sleep)
289 291
290 292 # Send the reply.
291 293 reply_msg = self.session.send(self.shell_socket, u'execute_reply',
292 294 reply_content, parent, ident=ident)
293 295 self.log.debug(str(reply_msg))
294 296
295 297 if reply_msg['content']['status'] == u'error':
296 298 self._abort_queue()
297 299
298 300 status_msg = self.session.send(self.iopub_socket,
299 301 u'status',
300 302 {u'execution_state':u'idle'},
301 303 parent=parent
302 304 )
303 305
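# Editor's sketch (not part of this changeset): the client-side content
# dict that execute_request above expects; the keys mirror what the
# handler reads out of parent[u'content'], and the values are made up.
example_content = {
    u'code': u'a = 10',
    u'silent': False,               # True -> run_code() path, no pyin broadcast
    u'user_variables': [u'a'],      # names looked up only on 'ok' status
    u'user_expressions': {u'b': u'a + 1'},  # evaluated only on 'ok' status
}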
304 306 def complete_request(self, ident, parent):
305 307 txt, matches = self._complete(parent)
306 308 matches = {'matches' : matches,
307 309 'matched_text' : txt,
308 310 'status' : 'ok'}
309 311 completion_msg = self.session.send(self.shell_socket, 'complete_reply',
310 312 matches, parent, ident)
311 313 self.log.debug(str(completion_msg))
312 314
313 315 def object_info_request(self, ident, parent):
314 316 object_info = self.shell.object_inspect(parent['content']['oname'])
315 317 # Before we send this object over, we scrub it for JSON usage
316 318 oinfo = json_clean(object_info)
317 319 msg = self.session.send(self.shell_socket, 'object_info_reply',
318 320 oinfo, parent, ident)
319 321 self.log.debug(msg)
320 322
321 323 def history_request(self, ident, parent):
322 324 # We need to pull these out, as passing **kwargs doesn't work with
323 325 # unicode keys before Python 2.6.5.
324 326 hist_access_type = parent['content']['hist_access_type']
325 327 raw = parent['content']['raw']
326 328 output = parent['content']['output']
327 329 if hist_access_type == 'tail':
328 330 n = parent['content']['n']
329 331 hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
330 332 include_latest=True)
331 333
332 334 elif hist_access_type == 'range':
333 335 session = parent['content']['session']
334 336 start = parent['content']['start']
335 337 stop = parent['content']['stop']
336 338 hist = self.shell.history_manager.get_range(session, start, stop,
337 339 raw=raw, output=output)
338 340
339 341 elif hist_access_type == 'search':
340 342 pattern = parent['content']['pattern']
341 343 hist = self.shell.history_manager.search(pattern, raw=raw, output=output)
342 344
343 345 else:
344 346 hist = []
345 347 content = {'history' : list(hist)}
346 348 msg = self.session.send(self.shell_socket, 'history_reply',
347 349 content, parent, ident)
348 350 self.log.debug(str(msg))
349 351
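# Editor's sketch (not from this diff): one request content per
# hist_access_type branch handled above; all values are illustrative.
tail_req = {'hist_access_type': 'tail', 'raw': True, 'output': False, 'n': 10}
range_req = {'hist_access_type': 'range', 'raw': True, 'output': False,
             'session': 0, 'start': 1, 'stop': 5}
search_req = {'hist_access_type': 'search', 'raw': True, 'output': False,
              'pattern': '*plot*'}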
350 352 def connect_request(self, ident, parent):
351 353 if self._recorded_ports is not None:
352 354 content = self._recorded_ports.copy()
353 355 else:
354 356 content = {}
355 357 msg = self.session.send(self.shell_socket, 'connect_reply',
356 358 content, parent, ident)
357 359 self.log.debug(msg)
358 360
359 361 def shutdown_request(self, ident, parent):
360 362 self.shell.exit_now = True
361 363 self._shutdown_message = self.session.msg(u'shutdown_reply', parent['content'], parent)
362 364 sys.exit(0)
363 365
364 366 #---------------------------------------------------------------------------
365 367 # Protected interface
366 368 #---------------------------------------------------------------------------
367 369
368 370 def _abort_queue(self):
369 371 while True:
370 372 ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
371 373 if msg is None:
372 374 break
373 375 else:
374 376 assert ident is not None, \
375 377 "Unexpected missing message part."
376 378
377 379 self.log.debug("Aborting:\n"+str(Message(msg)))
378 msg_type = msg['msg_type']
380 msg_type = msg['header']['msg_type']
379 381 reply_type = msg_type.split('_')[0] + '_reply'
380 382 reply_msg = self.session.send(self.shell_socket, reply_type,
381 383 {'status' : 'aborted'}, msg, ident=ident)
382 384 self.log.debug(reply_msg)
383 385 # We need to wait a bit for requests to come in. This can probably
384 386 # be set shorter for true asynchronous clients.
385 387 time.sleep(0.1)
386 388
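# The reply_type derivation above, worked for one concrete message type:
assert 'execute_request'.split('_')[0] + '_reply' == 'execute_reply'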
387 389 def _raw_input(self, prompt, ident, parent):
388 390 # Flush output before making the request.
389 391 sys.stderr.flush()
390 392 sys.stdout.flush()
391 393
392 394 # Send the input request.
393 395 content = dict(prompt=prompt)
394 396 msg = self.session.send(self.stdin_socket, u'input_request', content, parent)
395 397
396 398 # Await a response.
397 399 ident, reply = self.session.recv(self.stdin_socket, 0)
398 400 try:
399 401 value = reply['content']['value']
400 402 except:
401 403 self.log.error("Got bad raw_input reply: ")
402 404 self.log.error(str(Message(parent)))
403 405 value = ''
404 406 return value
405 407
406 408 def _complete(self, msg):
407 409 c = msg['content']
408 410 try:
409 411 cpos = int(c['cursor_pos'])
410 412 except:
411 413 # If we don't get something that we can convert to an integer, at
412 414 # least attempt the completion, guessing that the cursor is at the
413 415 # end of the text if there is any, and otherwise at the end of the line
414 416 cpos = len(c['text'])
415 417 if cpos==0:
416 418 cpos = len(c['line'])
417 419 return self.shell.complete(c['text'], c['line'], cpos)
418 420
419 421 def _object_info(self, context):
420 422 symbol, leftover = self._symbol_from_context(context)
421 423 if symbol is not None and not leftover:
422 424 doc = getattr(symbol, '__doc__', '')
423 425 else:
424 426 doc = ''
425 427 object_info = dict(docstring = doc)
426 428 return object_info
427 429
428 430 def _symbol_from_context(self, context):
429 431 if not context:
430 432 return None, context
431 433
432 434 base_symbol_string = context[0]
433 435 symbol = self.shell.user_ns.get(base_symbol_string, None)
434 436 if symbol is None:
435 437 symbol = __builtin__.__dict__.get(base_symbol_string, None)
436 438 if symbol is None:
437 439 return None, context
438 440
439 441 context = context[1:]
440 442 for i, name in enumerate(context):
441 443 new_symbol = getattr(symbol, name, None)
442 444 if new_symbol is None:
443 445 return symbol, context[i:]
444 446 else:
445 447 symbol = new_symbol
446 448
447 449 return symbol, []
448 450
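# Editor's sketch of the attribute walk _symbol_from_context performs,
# using the stdlib 'os' module as a stand-in for a user_ns entry:
import os
symbol = {'os': os}.get('os')            # resolve the base symbol
for name in ['path', 'join']:            # then getattr() down the dotted tail
    symbol = getattr(symbol, name, None)
assert symbol is os.path.join            # leftover would be [] here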
449 451 def _at_shutdown(self):
450 452 """Actions taken at shutdown by the kernel, called by python's atexit.
451 453 """
452 454 # io.rprint("Kernel at_shutdown") # dbg
453 455 if self._shutdown_message is not None:
454 456 self.session.send(self.shell_socket, self._shutdown_message)
455 457 self.session.send(self.iopub_socket, self._shutdown_message)
456 458 self.log.debug(str(self._shutdown_message))
457 459 # A very short sleep to give zmq time to flush its message buffers
458 460 # before Python truly shuts down.
459 461 time.sleep(0.01)
460 462
461 463
462 464 class QtKernel(Kernel):
463 465 """A Kernel subclass with Qt support."""
464 466
465 467 def start(self):
466 468 """Start a kernel with QtPy4 event loop integration."""
467 469
468 470 from IPython.external.qt_for_kernel import QtCore
469 471 from IPython.lib.guisupport import get_app_qt4, start_event_loop_qt4
470 472
471 473 self.app = get_app_qt4([" "])
472 474 self.app.setQuitOnLastWindowClosed(False)
473 475 self.timer = QtCore.QTimer()
474 476 self.timer.timeout.connect(self.do_one_iteration)
475 477 # Units for the timer are in milliseconds
476 478 self.timer.start(1000*self._poll_interval)
477 479 start_event_loop_qt4(self.app)
478 480
479 481
480 482 class WxKernel(Kernel):
481 483 """A Kernel subclass with Wx support."""
482 484
483 485 def start(self):
484 486 """Start a kernel with wx event loop support."""
485 487
486 488 import wx
487 489 from IPython.lib.guisupport import start_event_loop_wx
488 490
489 491 doi = self.do_one_iteration
490 492 # Wx uses milliseconds
491 493 poll_interval = int(1000*self._poll_interval)
492 494
493 495 # We have to put the wx.Timer in a wx.Frame for it to fire properly.
494 496 # We make the Frame hidden when we create it in the main app below.
495 497 class TimerFrame(wx.Frame):
496 498 def __init__(self, func):
497 499 wx.Frame.__init__(self, None, -1)
498 500 self.timer = wx.Timer(self)
499 501 # Units for the timer are in milliseconds
500 502 self.timer.Start(poll_interval)
501 503 self.Bind(wx.EVT_TIMER, self.on_timer)
502 504 self.func = func
503 505
504 506 def on_timer(self, event):
505 507 self.func()
506 508
507 509 # We need a custom wx.App to create our Frame subclass that has the
508 510 # wx.Timer to drive the ZMQ event loop.
509 511 class IPWxApp(wx.App):
510 512 def OnInit(self):
511 513 self.frame = TimerFrame(doi)
512 514 self.frame.Show(False)
513 515 return True
514 516
515 517 # The redirect=False here makes sure that wx doesn't replace
516 518 # sys.stdout/stderr with its own classes.
517 519 self.app = IPWxApp(redirect=False)
518 520 start_event_loop_wx(self.app)
519 521
520 522
521 523 class TkKernel(Kernel):
522 524 """A Kernel subclass with Tk support."""
523 525
524 526 def start(self):
525 527 """Start a Tk enabled event loop."""
526 528
527 529 import Tkinter
528 530 doi = self.do_one_iteration
529 531 # Tk uses milliseconds
530 532 poll_interval = int(1000*self._poll_interval)
531 533 # For Tkinter, we create a Tk object and call its withdraw method.
532 534 class Timer(object):
533 535 def __init__(self, func):
534 536 self.app = Tkinter.Tk()
535 537 self.app.withdraw()
536 538 self.func = func
537 539
538 540 def on_timer(self):
539 541 self.func()
540 542 self.app.after(poll_interval, self.on_timer)
541 543
542 544 def start(self):
543 545 self.on_timer() # Call it once to get things going.
544 546 self.app.mainloop()
545 547
546 548 self.timer = Timer(doi)
547 549 self.timer.start()
548 550
549 551
550 552 class GTKKernel(Kernel):
551 553 """A Kernel subclass with GTK support."""
552 554
553 555 def start(self):
554 556 """Start the kernel, coordinating with the GTK event loop"""
555 557 from .gui.gtkembed import GTKEmbed
556 558
557 559 gtk_kernel = GTKEmbed(self)
558 560 gtk_kernel.start()
559 561
560 562
561 563 #-----------------------------------------------------------------------------
562 564 # Aliases and Flags for the IPKernelApp
563 565 #-----------------------------------------------------------------------------
564 566
565 567 flags = dict(kernel_flags)
566 568 flags.update(shell_flags)
567 569
568 570 addflag = lambda *args: flags.update(boolean_flag(*args))
569 571
570 572 flags['pylab'] = (
571 573 {'IPKernelApp' : {'pylab' : 'auto'}},
572 574 """Pre-load matplotlib and numpy for interactive use with
573 575 the default matplotlib backend."""
574 576 )
575 577
576 578 aliases = dict(kernel_aliases)
577 579 aliases.update(shell_aliases)
578 580
579 581 # it's possible we don't want short aliases for *all* of these:
580 582 aliases.update(dict(
581 583 pylab='IPKernelApp.pylab',
582 584 ))
583 585
584 586 #-----------------------------------------------------------------------------
585 587 # The IPKernelApp class
586 588 #-----------------------------------------------------------------------------
587 589
588 590 class IPKernelApp(KernelApp, InteractiveShellApp):
589 591 name = 'ipkernel'
590 592
591 593 aliases = Dict(aliases)
592 594 flags = Dict(flags)
593 595 classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
594 596 # configurables
595 597 pylab = CaselessStrEnum(['tk', 'qt', 'wx', 'gtk', 'osx', 'inline', 'auto'],
596 598 config=True,
597 599 help="""Pre-load matplotlib and numpy for interactive use,
598 600 selecting a particular matplotlib backend and loop integration.
599 601 """
600 602 )
601 603 pylab_import_all = Bool(True, config=True,
602 604 help="""If true, an 'import *' is done from numpy and pylab,
603 605 when using pylab"""
604 606 )
605 607 def initialize(self, argv=None):
606 608 super(IPKernelApp, self).initialize(argv)
607 609 self.init_shell()
608 610 self.init_extensions()
609 611 self.init_code()
610 612
611 613 def init_kernel(self):
612 614 kernel_factory = Kernel
613 615
614 616 kernel_map = {
615 617 'qt' : QtKernel,
616 618 'qt4': QtKernel,
617 619 'inline': Kernel,
618 620 'osx': TkKernel,
619 621 'wx' : WxKernel,
620 622 'tk' : TkKernel,
621 623 'gtk': GTKKernel,
622 624 }
623 625
624 626 if self.pylab:
625 627 key = None if self.pylab == 'auto' else self.pylab
626 628 gui, backend = pylabtools.find_gui_and_backend(key)
627 629 kernel_factory = kernel_map.get(gui)
628 630 if kernel_factory is None:
629 631 raise ValueError('GUI is not supported: %r' % gui)
630 632 pylabtools.activate_matplotlib(backend)
631 633
632 634 kernel = kernel_factory(config=self.config, session=self.session,
633 635 shell_socket=self.shell_socket,
634 636 iopub_socket=self.iopub_socket,
635 637 stdin_socket=self.stdin_socket,
636 638 log=self.log
637 639 )
638 640 self.kernel = kernel
639 641 kernel.record_ports(self.ports)
640 642
641 643 if self.pylab:
642 644 import_all = self.pylab_import_all
643 645 pylabtools.import_pylab(kernel.shell.user_ns, backend, import_all,
644 646 shell=kernel.shell)
645 647
646 648 def init_shell(self):
647 649 self.shell = self.kernel.shell
648 650
649 651
650 652 #-----------------------------------------------------------------------------
651 653 # Kernel main and launch functions
652 654 #-----------------------------------------------------------------------------
653 655
654 656 def launch_kernel(*args, **kwargs):
655 657 """Launches a localhost IPython kernel, binding to the specified ports.
656 658
657 659 This function simply calls entry_point.base_launch_kernel with the command
658 660 needed to start an ipkernel. See base_launch_kernel for arguments.
659 661
660 662 Returns
661 663 -------
662 664 A tuple of form:
663 665 (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
664 666 where kernel_process is a Popen object and the ports are integers.
665 667 """
666 668 return base_launch_kernel('from IPython.zmq.ipkernel import main; main()',
667 669 *args, **kwargs)
668 670
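# Editor's usage sketch; the tuple unpacks per the docstring above, and any
# arguments are forwarded to base_launch_kernel:
#
# kernel_process, shell_port, iopub_port, stdin_port, hb_port = launch_kernel()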
669 671
670 672 def main():
671 673 """Run an IPKernel as an application"""
672 674 app = IPKernelApp.instance()
673 675 app.initialize()
674 676 app.start()
675 677
676 678
677 679 if __name__ == '__main__':
678 680 main()
@@ -1,278 +1,278 b''
1 1 #!/usr/bin/env python
2 2 """A simple interactive kernel that talks to a frontend over 0MQ.
3 3
4 4 Things to do:
5 5
6 6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 7 call set_parent on all the PUB objects with the message about to be executed.
8 8 * Implement random port and security key logic.
9 9 * Implement control messages.
10 10 * Implement event loop and poll version.
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 # Standard library imports.
18 18 import __builtin__
19 19 from code import CommandCompiler
20 20 import sys
21 21 import time
22 22 import traceback
23 23
24 24 # System library imports.
25 25 import zmq
26 26
27 27 # Local imports.
28 28 from IPython.utils.traitlets import HasTraits, Instance, Dict, Float
29 29 from completer import KernelCompleter
30 30 from entry_point import base_launch_kernel
31 31 from session import Session, Message
32 32 from kernelapp import KernelApp
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # Main kernel class
36 36 #-----------------------------------------------------------------------------
37 37
38 38 class Kernel(HasTraits):
39 39
40 40 # Private interface
41 41
42 42 # Time to sleep after flushing the stdout/err buffers in each execute
43 43 # cycle. While this introduces a hard limit on the minimal latency of the
44 44 # execute cycle, it helps prevent output synchronization problems for
45 45 # clients.
46 46 # Units are in seconds. The minimum zmq latency on local host is probably
47 47 # ~150 microseconds, set this to 500us for now. We may need to increase it
48 48 # a little if it's not enough after more interactive testing.
49 49 _execute_sleep = Float(0.0005, config=True)
50 50
51 51 # This is a dict of port number that the kernel is listening on. It is set
52 52 # by record_ports and used by connect_request.
53 53 _recorded_ports = Dict()
54 54
55 55 #---------------------------------------------------------------------------
56 56 # Kernel interface
57 57 #---------------------------------------------------------------------------
58 58
59 59 session = Instance(Session)
60 60 shell_socket = Instance('zmq.Socket')
61 61 iopub_socket = Instance('zmq.Socket')
62 62 stdin_socket = Instance('zmq.Socket')
63 63 log = Instance('logging.Logger')
64 64
65 65 def __init__(self, **kwargs):
66 66 super(Kernel, self).__init__(**kwargs)
67 67 self.user_ns = {}
68 68 self.history = []
69 69 self.compiler = CommandCompiler()
70 70 self.completer = KernelCompleter(self.user_ns)
71 71
72 72 # Build dict of handlers for message types
73 73 msg_types = [ 'execute_request', 'complete_request',
74 74 'object_info_request', 'shutdown_request' ]
75 75 self.handlers = {}
76 76 for msg_type in msg_types:
77 77 self.handlers[msg_type] = getattr(self, msg_type)
78 78
79 79 def start(self):
80 80 """ Start the kernel main loop.
81 81 """
82 82 while True:
83 83 ident,msg = self.session.recv(self.shell_socket,0)
84 84 assert ident is not None, "Missing message part."
85 85 omsg = Message(msg)
86 86 self.log.debug(str(omsg))
87 87 handler = self.handlers.get(omsg.msg_type, None)
88 88 if handler is None:
89 89 self.log.error("UNKNOWN MESSAGE TYPE: %s"%omsg)
90 90 else:
91 91 handler(ident, omsg)
92 92
93 93 def record_ports(self, ports):
94 94 """Record the ports that this kernel is using.
95 95
96 96 The creator of the Kernel instance must call this method if they
97 97 want the :meth:`connect_request` method to return the port numbers.
98 98 """
99 99 self._recorded_ports = ports
100 100
101 101 #---------------------------------------------------------------------------
102 102 # Kernel request handlers
103 103 #---------------------------------------------------------------------------
104 104
105 105 def execute_request(self, ident, parent):
106 106 try:
107 107 code = parent[u'content'][u'code']
108 108 except:
109 109 self.log.error("Got bad msg: %s"%Message(parent))
110 110 return
111 111 pyin_msg = self.session.send(self.iopub_socket, u'pyin',{u'code':code}, parent=parent)
112 112
113 113 try:
114 114 comp_code = self.compiler(code, '<zmq-kernel>')
115 115
116 116 # Replace raw_input. Note that it is not sufficient to replace
117 117 # raw_input in the user namespace.
118 118 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
119 119 __builtin__.raw_input = raw_input
120 120
121 121 # Set the parent message of the display hook and out streams.
122 122 sys.displayhook.set_parent(parent)
123 123 sys.stdout.set_parent(parent)
124 124 sys.stderr.set_parent(parent)
125 125
126 126 exec comp_code in self.user_ns, self.user_ns
127 127 except:
128 128 etype, evalue, tb = sys.exc_info()
129 129 tb = traceback.format_exception(etype, evalue, tb)
130 130 exc_content = {
131 131 u'status' : u'error',
132 132 u'traceback' : tb,
133 133 u'ename' : unicode(etype.__name__),
134 134 u'evalue' : unicode(evalue)
135 135 }
136 136 exc_msg = self.session.send(self.iopub_socket, u'pyerr', exc_content, parent)
137 137 reply_content = exc_content
138 138 else:
139 139 reply_content = { 'status' : 'ok', 'payload' : {} }
140 140
141 141 # Flush output before sending the reply.
142 142 sys.stderr.flush()
143 143 sys.stdout.flush()
144 144 # FIXME: on rare occasions, the flush doesn't seem to make it to the
145 145 # clients... This seems to mitigate the problem, but we definitely need
146 146 # to better understand what's going on.
147 147 if self._execute_sleep:
148 148 time.sleep(self._execute_sleep)
149 149
150 150 # Send the reply.
151 151 reply_msg = self.session.send(self.shell_socket, u'execute_reply', reply_content, parent, ident=ident)
152 152 self.log.debug(Message(reply_msg))
153 153 if reply_msg['content']['status'] == u'error':
154 154 self._abort_queue()
155 155
156 156 def complete_request(self, ident, parent):
157 157 matches = {'matches' : self._complete(parent),
158 158 'status' : 'ok'}
159 159 completion_msg = self.session.send(self.shell_socket, 'complete_reply',
160 160 matches, parent, ident)
161 161 self.log.debug(completion_msg)
162 162
163 163 def object_info_request(self, ident, parent):
164 164 context = parent['content']['oname'].split('.')
165 165 object_info = self._object_info(context)
166 166 msg = self.session.send(self.shell_socket, 'object_info_reply',
167 167 object_info, parent, ident)
168 168 self.log.debug(msg)
169 169
170 170 def shutdown_request(self, ident, parent):
171 171 content = dict(parent['content'])
172 172 msg = self.session.send(self.shell_socket, 'shutdown_reply',
173 173 content, parent, ident)
174 174 msg = self.session.send(self.iopub_socket, 'shutdown_reply',
175 175 content, parent, ident)
176 176 self.log.debug(msg)
177 177 time.sleep(0.1)
178 178 sys.exit(0)
179 179
180 180 #---------------------------------------------------------------------------
181 181 # Protected interface
182 182 #---------------------------------------------------------------------------
183 183
184 184 def _abort_queue(self):
185 185 while True:
186 186 ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
187 187 if msg is None:
188 188 # msg=None on EAGAIN
189 189 break
190 190 else:
191 191 assert ident is not None, "Missing message part."
192 192 self.log.debug("Aborting: %s"%Message(msg))
193 msg_type = msg['msg_type']
193 msg_type = msg['header']['msg_type']
194 194 reply_type = msg_type.split('_')[0] + '_reply'
195 195 reply_msg = self.session.send(self.shell_socket, reply_type, {'status':'aborted'}, msg, ident=ident)
196 196 self.log.debug(Message(reply_msg))
197 197 # We need to wait a bit for requests to come in. This can probably
198 198 # be set shorter for true asynchronous clients.
199 199 time.sleep(0.1)
200 200
201 201 def _raw_input(self, prompt, ident, parent):
202 202 # Flush output before making the request.
203 203 sys.stderr.flush()
204 204 sys.stdout.flush()
205 205
206 206 # Send the input request.
207 207 content = dict(prompt=prompt)
208 208 msg = self.session.send(self.stdin_socket, u'input_request', content, parent)
209 209
210 210 # Await a response.
211 211 ident,reply = self.session.recv(self.stdin_socket, 0)
212 212 try:
213 213 value = reply['content']['value']
214 214 except:
215 215 self.log.error("Got bad raw_input reply: %s"%Message(parent))
216 216 value = ''
217 217 return value
218 218
219 219 def _complete(self, msg):
220 220 return self.completer.complete(msg.content.line, msg.content.text)
221 221
222 222 def _object_info(self, context):
223 223 symbol, leftover = self._symbol_from_context(context)
224 224 if symbol is not None and not leftover:
225 225 doc = getattr(symbol, '__doc__', '')
226 226 else:
227 227 doc = ''
228 228 object_info = dict(docstring = doc)
229 229 return object_info
230 230
231 231 def _symbol_from_context(self, context):
232 232 if not context:
233 233 return None, context
234 234
235 235 base_symbol_string = context[0]
236 236 symbol = self.user_ns.get(base_symbol_string, None)
237 237 if symbol is None:
238 238 symbol = __builtin__.__dict__.get(base_symbol_string, None)
239 239 if symbol is None:
240 240 return None, context
241 241
242 242 context = context[1:]
243 243 for i, name in enumerate(context):
244 244 new_symbol = getattr(symbol, name, None)
245 245 if new_symbol is None:
246 246 return symbol, context[i:]
247 247 else:
248 248 symbol = new_symbol
249 249
250 250 return symbol, []
251 251
252 252 #-----------------------------------------------------------------------------
253 253 # Kernel main and launch functions
254 254 #-----------------------------------------------------------------------------
255 255
256 256 def launch_kernel(*args, **kwargs):
257 257 """ Launches a simple Python kernel, binding to the specified ports.
258 258
259 259 This function simply calls entry_point.base_launch_kernel with the command
260 260 needed to start a pykernel. See base_launch_kernel for arguments.
261 261
262 262 Returns
263 263 -------
264 264 A tuple of form:
265 265 (kernel_process, xrep_port, pub_port, req_port, hb_port)
266 266 where kernel_process is a Popen object and the ports are integers.
267 267 """
268 268 return base_launch_kernel('from IPython.zmq.pykernel import main; main()',
269 269 *args, **kwargs)
270 270
271 271 def main():
272 272 """Run a PyKernel as an application"""
273 273 app = KernelApp.instance()
274 274 app.initialize()
275 275 app.start()
276 276
277 277 if __name__ == '__main__':
278 278 main()
@@ -1,679 +1,697 b''
1 1 #!/usr/bin/env python
2 2 """Session object for building, serializing, sending, and receiving messages in
3 3 IPython. The Session object supports serialization, HMAC signatures, and
4 4 metadata on messages.
5 5
6 6 Also defined here are utilities for working with Sessions:
7 7 * A SessionFactory to be used as a base class for configurables that work with
8 8 Sessions.
9 9 * A Message object for convenience that allows attribute-access to the msg dict.
10 10
11 11 Authors:
12 12
13 13 * Min RK
14 14 * Brian Granger
15 15 * Fernando Perez
16 16 """
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2010-2011 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import uuid
33 33 from datetime import datetime
34 34
35 35 try:
36 36 import cPickle
37 37 pickle = cPickle
38 38 except:
39 39 cPickle = None
40 40 import pickle
41 41
42 42 import zmq
43 43 from zmq.utils import jsonapi
44 44 from zmq.eventloop.ioloop import IOLoop
45 45 from zmq.eventloop.zmqstream import ZMQStream
46 46
47 47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 48 from IPython.utils.importstring import import_item
49 49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 51 DottedObjectName)
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # utility functions
55 55 #-----------------------------------------------------------------------------
56 56
57 57 def squash_unicode(obj):
58 58 """coerce unicode back to bytestrings."""
59 59 if isinstance(obj,dict):
60 60 for key in obj.keys():
61 61 obj[key] = squash_unicode(obj[key])
62 62 if isinstance(key, unicode):
63 63 obj[squash_unicode(key)] = obj.pop(key)
64 64 elif isinstance(obj, list):
65 65 for i,v in enumerate(obj):
66 66 obj[i] = squash_unicode(v)
67 67 elif isinstance(obj, unicode):
68 68 obj = obj.encode('utf8')
69 69 return obj
70 70
71 71 #-----------------------------------------------------------------------------
72 72 # globals and defaults
73 73 #-----------------------------------------------------------------------------
74 74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77 77
78 78 pickle_packer = lambda o: pickle.dumps(o,-1)
79 79 pickle_unpacker = pickle.loads
80 80
81 81 default_packer = json_packer
82 82 default_unpacker = json_unpacker
83 83
84 84
85 85 DELIM=b"<IDS|MSG>"
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # Classes
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SessionFactory(LoggingConfigurable):
92 92 """The Base class for configurables that have a Session, Context, logger,
93 93 and IOLoop.
94 94 """
95 95
96 96 logname = Unicode('')
97 97 def _logname_changed(self, name, old, new):
98 98 self.log = logging.getLogger(new)
99 99
100 100 # not configurable:
101 101 context = Instance('zmq.Context')
102 102 def _context_default(self):
103 103 return zmq.Context.instance()
104 104
105 105 session = Instance('IPython.zmq.session.Session')
106 106
107 107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 108 def _loop_default(self):
109 109 return IOLoop.instance()
110 110
111 111 def __init__(self, **kwargs):
112 112 super(SessionFactory, self).__init__(**kwargs)
113 113
114 114 if self.session is None:
115 115 # construct the session
116 116 self.session = Session(**kwargs)
117 117
118 118
119 119 class Message(object):
120 120 """A simple message object that maps dict keys to attributes.
121 121
122 122 A Message can be created from a dict and a dict from a Message instance
123 123 simply by calling dict(msg_obj)."""
124 124
125 125 def __init__(self, msg_dict):
126 126 dct = self.__dict__
127 127 for k, v in dict(msg_dict).iteritems():
128 128 if isinstance(v, dict):
129 129 v = Message(v)
130 130 dct[k] = v
131 131
132 132 # Having this iterator lets dict(msg_obj) work out of the box.
133 133 def __iter__(self):
134 134 return iter(self.__dict__.iteritems())
135 135
136 136 def __repr__(self):
137 137 return repr(self.__dict__)
138 138
139 139 def __str__(self):
140 140 return pprint.pformat(self.__dict__)
141 141
142 142 def __contains__(self, k):
143 143 return k in self.__dict__
144 144
145 145 def __getitem__(self, k):
146 146 return self.__dict__[k]
147 147
148 148
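# Quick illustration of the attribute access described above; the keys are
# arbitrary sample values:
m = Message({'header': {'msg_id': 'abc'}, 'content': {}})
assert m.header.msg_id == 'abc'          # nested dicts become Messages
assert 'content' in m and m['header']['msg_id'] == 'abc'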
149 149 def msg_header(msg_id, msg_type, username, session):
150 150 date = datetime.now()
151 151 return locals()
152 152
153 153 def extract_header(msg_or_header):
154 154 """Given a message or header, return the header."""
155 155 if not msg_or_header:
156 156 return {}
157 157 try:
158 158 # See if msg_or_header is the entire message.
159 159 h = msg_or_header['header']
160 160 except KeyError:
161 161 try:
162 162 # See if msg_or_header is just the header
163 163 h = msg_or_header['msg_id']
164 164 except KeyError:
165 165 raise
166 166 else:
167 167 h = msg_or_header
168 168 if not isinstance(h, dict):
169 169 h = dict(h)
170 170 return h
171 171
172 172 class Session(Configurable):
173 173 """Object for handling serialization and sending of messages.
174 174
175 175 The Session object handles building messages and sending them
176 176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 177 other over the network via Session objects, and only need to work with the
178 178 dict-based IPython message spec. The Session will handle
179 179 serialization/deserialization, security, and metadata.
180 180
181 181 Sessions support configurable serialization via packer/unpacker traits,
182 182 and signing with HMAC digests via the key/keyfile traits.
183 183
184 184 Parameters
185 185 ----------
186 186
187 187 debug : bool
188 188 whether to trigger extra debugging statements
189 189 packer/unpacker : str : 'json', 'pickle' or import_string
190 190 importstrings for methods to serialize message parts. If just
191 191 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 192 Otherwise, the entire importstring must be used.
193 193
194 194 The functions must accept at least valid JSON input, and output *bytes*.
195 195
196 196 For example, to use msgpack:
197 197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 198 pack/unpack : callables
199 199 You can also set the pack/unpack callables for serialization directly.
200 200 session : bytes
201 201 the ID of this Session object. The default is to generate a new UUID.
202 202 username : unicode
203 203 username added to message headers. The default is to ask the OS.
204 204 key : bytes
205 205 The key used to initialize an HMAC signature. If unset, messages
206 206 will not be signed or checked.
207 207 keyfile : filepath
208 208 The file containing a key. If this is set, `key` will be initialized
209 209 to the contents of the file.
210 210
211 211 """
212 212
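# Concrete form of the msgpack example from the docstring above (a sketch;
# assumes the msgpack package is installed):
#
# session = Session(packer='msgpack.packb', unpacker='msgpack.unpackb')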
213 213 debug=Bool(False, config=True, help="""Debug output in the Session""")
214 214
215 215 packer = DottedObjectName('json',config=True,
216 216 help="""The name of the packer for serializing messages.
217 217 Should be one of 'json', 'pickle', or an import name
218 218 for a custom callable serializer.""")
219 219 def _packer_changed(self, name, old, new):
220 220 if new.lower() == 'json':
221 221 self.pack = json_packer
222 222 self.unpack = json_unpacker
223 223 elif new.lower() == 'pickle':
224 224 self.pack = pickle_packer
225 225 self.unpack = pickle_unpacker
226 226 else:
227 227 self.pack = import_item(str(new))
228 228
229 229 unpacker = DottedObjectName('json', config=True,
230 230 help="""The name of the unpacker for unserializing messages.
231 231 Only used with custom functions for `packer`.""")
232 232 def _unpacker_changed(self, name, old, new):
233 233 if new.lower() == 'json':
234 234 self.pack = json_packer
235 235 self.unpack = json_unpacker
236 236 elif new.lower() == 'pickle':
237 237 self.pack = pickle_packer
238 238 self.unpack = pickle_unpacker
239 239 else:
240 240 self.unpack = import_item(str(new))
241 241
242 242 session = CBytes(b'', config=True,
243 243 help="""The UUID identifying this session.""")
244 244 def _session_default(self):
245 245 return bytes(uuid.uuid4())
246 246
247 username = Unicode(os.environ.get('USER','username'), config=True,
247 username = Unicode(os.environ.get('USER',u'username'), config=True,
248 248 help="""Username for the Session. Default is your system username.""")
249 249
250 250 # message signature related traits:
251 251 key = CBytes(b'', config=True,
252 252 help="""execution key, for extra authentication.""")
253 253 def _key_changed(self, name, old, new):
254 254 if new:
255 255 self.auth = hmac.HMAC(new)
256 256 else:
257 257 self.auth = None
258 258 auth = Instance(hmac.HMAC)
259 259 digest_history = Set()
260 260
261 261 keyfile = Unicode('', config=True,
262 262 help="""path to file containing execution key.""")
263 263 def _keyfile_changed(self, name, old, new):
264 264 with open(new, 'rb') as f:
265 265 self.key = f.read().strip()
266 266
267 267 pack = Any(default_packer) # the actual packer function
268 268 def _pack_changed(self, name, old, new):
269 269 if not callable(new):
270 270 raise TypeError("packer must be callable, not %s"%type(new))
271 271
272 272 unpack = Any(default_unpacker) # the actual packer function
273 273 def _unpack_changed(self, name, old, new):
274 274 # unpacker is not checked - it is assumed to be the inverse of pack
275 275 if not callable(new):
276 276 raise TypeError("unpacker must be callable, not %s"%type(new))
277 277
278 278 def __init__(self, **kwargs):
279 279 """create a Session object
280 280
281 281 Parameters
282 282 ----------
283 283
284 284 debug : bool
285 285 whether to trigger extra debugging statements
286 286 packer/unpacker : str : 'json', 'pickle' or import_string
287 287 importstrings for methods to serialize message parts. If just
288 288 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 289 Otherwise, the entire importstring must be used.
290 290
291 291 The functions must accept at least valid JSON input, and output
292 292 *bytes*.
293 293
294 294 For example, to use msgpack:
295 295 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 296 pack/unpack : callables
297 297 You can also set the pack/unpack callables for serialization
298 298 directly.
299 299 session : bytes
300 300 the ID of this Session object. The default is to generate a new
301 301 UUID.
302 302 username : unicode
303 303 username added to message headers. The default is to ask the OS.
304 304 key : bytes
305 305 The key used to initialize an HMAC signature. If unset, messages
306 306 will not be signed or checked.
307 307 keyfile : filepath
308 308 The file containing a key. If this is set, `key` will be
309 309 initialized to the contents of the file.
310 310 """
311 311 super(Session, self).__init__(**kwargs)
312 312 self._check_packers()
313 313 self.none = self.pack({})
314 314
315 315 @property
316 316 def msg_id(self):
317 317 """always return new uuid"""
318 318 return str(uuid.uuid4())
319 319
320 320 def _check_packers(self):
321 321 """check packers for binary data and datetime support."""
322 322 pack = self.pack
323 323 unpack = self.unpack
324 324
325 325 # check simple serialization
326 326 msg = dict(a=[1,'hi'])
327 327 try:
328 328 packed = pack(msg)
329 329 except Exception:
330 330 raise ValueError("packer could not serialize a simple message")
331 331
332 332 # ensure packed message is bytes
333 333 if not isinstance(packed, bytes):
334 334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 335
336 336 # check that unpack is pack's inverse
337 337 try:
338 338 unpacked = unpack(packed)
339 339 except Exception:
340 340 raise ValueError("unpacker could not handle the packer's output")
341 341
342 342 # check datetime support
343 343 msg = dict(t=datetime.now())
344 344 try:
345 345 unpacked = unpack(pack(msg))
346 346 except Exception:
347 347 self.pack = lambda o: pack(squash_dates(o))
348 348 self.unpack = lambda s: extract_dates(unpack(s))
349 349
350 350 def msg_header(self, msg_type):
351 351 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 352
353 def msg(self, msg_type, content=None, parent=None, subheader=None):
353 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 354 """Return the nested message dict.
355 355
356 356 This format is different from what is sent over the wire. The
357 self.serialize method converts this nested message dict to the wire
358 format, which uses a message list.
357 serialize/unserialize methods convert this nested message dict to the wire
358 format, which is a list of message parts.
359 359 """
360 360 msg = {}
361 msg['header'] = self.msg_header(msg_type)
362 msg['msg_id'] = msg['header']['msg_id']
361 msg['header'] = self.msg_header(msg_type) if header is None else header
363 362 msg['parent_header'] = {} if parent is None else extract_header(parent)
364 msg['msg_type'] = msg_type
365 363 msg['content'] = {} if content is None else content
366 364 sub = {} if subheader is None else subheader
367 365 msg['header'].update(sub)
368 366 return msg
369 367
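# Editor's sketch of the nested dict msg() returns; the header fields come
# from msg_header() above (msg_id, msg_type, username, session, date):
#
# {'header': {'msg_id': '...', 'msg_type': 'execute_request',
#             'username': '...', 'session': '...', 'date': datetime(...)},
#  'parent_header': {},
#  'content': {}}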
370 368 def sign(self, msg_list):
371 369 """Sign a message with HMAC digest. If no auth, return b''.
372 370
373 371 Parameters
374 372 ----------
375 373 msg_list : list
376 374 The [p_header,p_parent,p_content] part of the message list.
377 375 """
378 376 if self.auth is None:
379 377 return b''
380 378 h = self.auth.copy()
381 379 for m in msg_list:
382 380 h.update(m)
383 381 return h.hexdigest()
384 382
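# What sign() computes, spelled out with the stdlib directly (a sketch;
# the key and message parts are made up):
import hmac
h = hmac.HMAC(b'my-key')                 # same construction as _key_changed
for part in [b'{"msg_id":"1"}', b'{}', b'{}']:  # p_header, p_parent, p_content
    h.update(part)
signature = h.hexdigest()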
385 383 def serialize(self, msg, ident=None):
386 384 """Serialize the message components to bytes.
387 385
386 This is roughly the inverse of unserialize. The serialize/unserialize
387 methods work with full message lists, whereas pack/unpack work with
388 the individual message parts in the message list.
389
388 390 Parameters
389 391 ----------
390 392 msg : dict or Message
391 393 The nested message dict as returned by the self.msg method.
392 394
393 395 Returns
394 396 -------
395 397 msg_list : list
396 398 The list of bytes objects to be sent with the format:
397 399 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
398 400 buffer1,buffer2,...]. In this list, the p_* entities are
399 401 the packed or serialized versions, so if JSON is used, these
400 402 are utf8 encoded JSON strings.
401 403 """
402 404 content = msg.get('content', {})
403 405 if content is None:
404 406 content = self.none
405 407 elif isinstance(content, dict):
406 408 content = self.pack(content)
407 409 elif isinstance(content, bytes):
408 410 # content is already packed, as in a relayed message
409 411 pass
410 412 elif isinstance(content, unicode):
411 413 # should be bytes, but JSON often spits out unicode
412 414 content = content.encode('utf8')
413 415 else:
414 416 raise TypeError("Content incorrect type: %s"%type(content))
415 417
416 418 real_message = [self.pack(msg['header']),
417 419 self.pack(msg['parent_header']),
418 420 content
419 421 ]
420 422
421 423 to_send = []
422 424
423 425 if isinstance(ident, list):
424 426 # accept list of idents
425 427 to_send.extend(ident)
426 428 elif ident is not None:
427 429 to_send.append(ident)
428 430 to_send.append(DELIM)
429 431
430 432 signature = self.sign(real_message)
431 433 to_send.append(signature)
432 434
433 435 to_send.extend(real_message)
434 436
435 437 return to_send
436 438
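# Layout of the list serialize() returns, for a single ident (made-up
# values; p_* stand for the packed header/parent/content byte strings):
#
# [b'client-ident', b'<IDS|MSG>', b'<hmac-hexdigest-or-empty>',
#  p_header, p_parent, p_content]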
437 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
438 buffers=None, subheader=None, track=False):
439 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
440 buffers=None, subheader=None, track=False, header=None):
439 441 """Build and send a message via stream or socket.
440 442
441 443 The message format used by this function internally is as follows:
442 444
443 445 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
444 446 buffer1,buffer2,...]
445 447
446 The self.serialize method converts the nested message dict into this
448 The serialize/unserialize methods convert the nested message dict into this
447 449 format.
448 450
449 451 Parameters
450 452 ----------
451 453
452 454 stream : zmq.Socket or ZMQStream
453 the socket-like object used to send the data
455 The socket-like object used to send the data.
454 456 msg_or_type : str or Message/dict
455 457 Normally, msg_or_type will be a msg_type unless a message is being
456 sent more than once.
458 sent more than once. If a header is supplied, this can be set to
459 None and the msg_type will be pulled from the header.
457 460
458 461 content : dict or None
459 the content of the message (ignored if msg_or_type is a message)
462 The content of the message (ignored if msg_or_type is a message).
463 header : dict or None
463 464 The header dict for the message (ignored if msg_or_type is a message).
460 465 parent : Message or dict or None
461 the parent or parent header describing the parent of this message
466 The parent or parent header describing the parent of this message
467 (ignored if msg_or_type is a message).
462 468 ident : bytes or list of bytes
463 the zmq.IDENTITY routing path
469 The zmq.IDENTITY routing path.
464 470 subheader : dict or None
465 extra header keys for this message's header
471 Extra header keys for this message's header (ignored if msg_or_type
472 is a message).
466 473 buffers : list or None
467 the already-serialized buffers to be appended to the message
474 The already-serialized buffers to be appended to the message.
468 475 track : bool
469 whether to track. Only for use with Sockets,
470 because ZMQStream objects cannot track messages.
476 Whether to track. Only for use with Sockets, because ZMQStream
477 objects cannot track messages.
471 478
472 479 Returns
473 480 -------
474 msg : message dict
475 the constructed message
476 (msg,tracker) : (message dict, MessageTracker)
481 msg : dict
482 The constructed message.
483 (msg,tracker) : (dict, MessageTracker)
477 484 if track=True, then a 2-tuple will be returned,
478 485 the first element being the constructed
479 486 message, and the second being the MessageTracker
480 487
481 488 """
482 489
483 490 if not isinstance(stream, (zmq.Socket, ZMQStream)):
484 491 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
485 492 elif track and isinstance(stream, ZMQStream):
486 493 raise TypeError("ZMQStream cannot track messages")
487 494
488 495 if isinstance(msg_or_type, (Message, dict)):
489 # we got a Message, not a msg_type
490 # don't build a new Message
496 # We got a Message or message dict, not a msg_type so don't
497 # build a new Message.
491 498 msg = msg_or_type
492 499 else:
493 msg = self.msg(msg_or_type, content, parent, subheader)
494
500 msg = self.msg(msg_or_type, content=content, parent=parent,
501 subheader=subheader, header=header)
502
495 503 buffers = [] if buffers is None else buffers
496 504 to_send = self.serialize(msg, ident)
497 505 flag = 0
498 506 if buffers:
499 507 flag = zmq.SNDMORE
500 508 _track = False
501 509 else:
502 510 _track=track
503 511 if track:
504 512 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
505 513 else:
506 514 tracker = stream.send_multipart(to_send, flag, copy=False)
507 515 for b in buffers[:-1]:
508 516 stream.send(b, flag, copy=False)
509 517 if buffers:
510 518 if track:
511 519 tracker = stream.send(buffers[-1], copy=False, track=track)
512 520 else:
513 521 tracker = stream.send(buffers[-1], copy=False)
514 522
515 523 # omsg = Message(msg)
516 524 if self.debug:
517 525 pprint.pprint(msg)
518 526 pprint.pprint(to_send)
519 527 pprint.pprint(buffers)
520 528
521 529 msg['tracker'] = tracker
522 530
523 531 return msg
524
532
525 533 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
526 534 """Send a raw message via ident path.
527 535
528 536 This method is used to send an already serialized message.
529 537
530 538 Parameters
531 539 ----------
532 540 stream : ZMQStream or Socket
533 541 The ZMQ stream or socket to use for sending the message.
534 542 msg_list : list
535 543 The serialized list of messages to send. This only includes the
536 544 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
537 545 the message.
538 546 ident : ident or list
539 547 A single ident or a list of idents to use in sending.
540 548 """
541 549 to_send = []
542 550 if isinstance(ident, bytes):
543 551 ident = [ident]
544 552 if ident is not None:
545 553 to_send.extend(ident)
546
554
547 555 to_send.append(DELIM)
548 556 to_send.append(self.sign(msg_list))
549 557 to_send.extend(msg_list)
550 558 stream.send_multipart(msg_list, flags, copy=copy)
551 559
552 560 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
553 561 """Receive and unpack a message.
554 562
555 563 Parameters
556 564 ----------
557 565 socket : ZMQStream or Socket
558 566 The socket or stream to use in receiving.
559 567
560 568 Returns
561 569 -------
562 570 [idents], msg
563 571 [idents] is a list of idents and msg is a nested message dict of
564 572 same format as self.msg returns.
565 573 """
566 574 if isinstance(socket, ZMQStream):
567 575 socket = socket.socket
568 576 try:
569 577 msg_list = socket.recv_multipart(mode)
570 578 except zmq.ZMQError as e:
571 579 if e.errno == zmq.EAGAIN:
572 580 # We can convert EAGAIN to None as we know in this case
573 581 # recv_multipart won't return None.
574 582 return None,None
575 583 else:
576 584 raise
577 585 # split multipart message into identity list and message dict
578 586 # invalid large messages can cause very expensive string comparisons
579 587 idents, msg_list = self.feed_identities(msg_list, copy)
580 588 try:
581 return idents, self.unpack_message(msg_list, content=content, copy=copy)
589 return idents, self.unserialize(msg_list, content=content, copy=copy)
582 590 except Exception as e:
583 591 print (idents, msg_list)
584 592 # TODO: handle it
585 593 raise e
586 594
587 595 def feed_identities(self, msg_list, copy=True):
588 596 """Split the identities from the rest of the message.
589 597
590 598 Feed until DELIM is reached, then return the prefix as idents and
591 599 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
592 600 but that would be silly.
593 601
594 602 Parameters
595 603 ----------
596 604 msg_list : a list of Message or bytes objects
597 605 The message to be split.
598 606 copy : bool
599 607 flag determining whether the arguments are bytes or Messages
600 608
601 609 Returns
602 610 -------
603 (idents,msg_list) : two lists
604 idents will always be a list of bytes - the indentity prefix
605 msg_list will be a list of bytes or Messages, unchanged from input
606 msg_list should be unpackable via self.unpack_message at this point.
611 (idents, msg_list) : two lists
612 idents will always be a list of bytes, each of which is a ZMQ
613 identity. msg_list will be a list of bytes or zmq.Messages of the
614 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
615 should be unpackable/unserializable via self.unserialize at this
616 point.
607 617 """
608 618 if copy:
609 619 idx = msg_list.index(DELIM)
610 620 return msg_list[:idx], msg_list[idx+1:]
611 621 else:
612 622 failed = True
613 623 for idx,m in enumerate(msg_list):
614 624 if m.bytes == DELIM:
615 625 failed = False
616 626 break
617 627 if failed:
618 628 raise ValueError("DELIM not in msg_list")
619 629 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
620 630 return [m.bytes for m in idents], msg_list
621 631
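# Worked example of the copy=True split above, with made-up parts:
wire = [b'ident1', b'<IDS|MSG>', b'sig', b'header', b'parent', b'content']
idx = wire.index(b'<IDS|MSG>')
idents, rest = wire[:idx], wire[idx+1:]
assert idents == [b'ident1']
assert rest == [b'sig', b'header', b'parent', b'content']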
622 def unpack_message(self, msg_list, content=True, copy=True):
623 """Return a message object from the format
624 sent by self.send.
625
632 def unserialize(self, msg_list, content=True, copy=True):
633 """Unserialize a msg_list to a nested message dict.
634
635 This is roughly the inverse of serialize. The serialize/unserialize
636 methods work with full message lists, whereas pack/unpack work with
637 the individual message parts in the message list.
638
626 639 Parameters:
627 640 -----------
628
641 msg_list : list of bytes or Message objects
642 The list of message parts of the form [HMAC,p_header,p_parent,
643 p_content,buffer1,buffer2,...].
629 644 content : bool (True)
630 whether to unpack the content dict (True),
631 or leave it serialized (False)
632
645 Whether to unpack the content dict (True), or leave it packed
646 (False).
633 647 copy : bool (True)
634 whether to return the bytes (True),
635 or the non-copying Message object in each place (False)
636
648 Whether to return the bytes (True), or the non-copying Message
649 object in each place (False).
650
651 Returns
652 -------
653 msg : dict
654 The nested message dict with top-level keys [header, parent_header,
655 content, buffers].
637 656 """
638 657 minlen = 4
639 658 message = {}
640 659 if not copy:
641 660 for i in range(minlen):
642 661 msg_list[i] = msg_list[i].bytes
643 662 if self.auth is not None:
644 663 signature = msg_list[0]
645 664 if signature in self.digest_history:
646 665 raise ValueError("Duplicate Signature: %r"%signature)
647 666 self.digest_history.add(signature)
648 667 check = self.sign(msg_list[1:4])
649 668 if not signature == check:
650 669 raise ValueError("Invalid Signature: %r"%signature)
651 670 if not len(msg_list) >= minlen:
652 671 raise TypeError("malformed message, must have at least %i elements"%minlen)
653 672 message['header'] = self.unpack(msg_list[1])
654 message['msg_type'] = message['header']['msg_type']
655 673 message['parent_header'] = self.unpack(msg_list[2])
656 674 if content:
657 675 message['content'] = self.unpack(msg_list[3])
658 676 else:
659 677 message['content'] = msg_list[3]
660 678
661 679 message['buffers'] = msg_list[4:]
662 680 return message
663 681
664 682 def test_msg2obj():
665 683 am = dict(x=1)
666 684 ao = Message(am)
667 685 assert ao.x == am['x']
668 686
669 687 am['y'] = dict(z=1)
670 688 ao = Message(am)
671 689 assert ao.y.z == am['y']['z']
672 690
673 691 k1, k2 = 'y', 'z'
674 692 assert ao[k1][k2] == am[k1][k2]
675 693
676 694 am2 = dict(ao)
677 695 assert am['x'] == am2['x']
678 696 assert am['y']['z'] == am2['y']['z']
679 697
@@ -1,111 +1,177 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29
30 class MockSocket(zmq.Socket):
31
32 def __init__(self, *args, **kwargs):
33 super(MockSocket,self).__init__(*args,**kwargs)
34 self.data = []
35
36 def send_multipart(self, msgparts, *args, **kwargs):
37 self.data.extend(msgparts)
38
39 def send(self, part, *args, **kwargs):
40 self.data.append(part)
41
42 def recv_multipart(self, *args, **kwargs):
43 return self.data
44
29 45 class TestSession(SessionTestCase):
30 46
31 47 def test_msg(self):
32 48 """message format"""
33 49 msg = self.session.msg('execute')
34 thekeys = set('header msg_id parent_header msg_type content'.split())
50 thekeys = set('header parent_header content'.split())
35 51 s = set(msg.keys())
36 52 self.assertEquals(s, thekeys)
37 53 self.assertTrue(isinstance(msg['content'],dict))
38 54 self.assertTrue(isinstance(msg['header'],dict))
39 55 self.assertTrue(isinstance(msg['parent_header'],dict))
40 self.assertEquals(msg['msg_type'], 'execute')
41
42
43
56 self.assertEquals(msg['header']['msg_type'], 'execute')
57
58 def test_serialize(self):
59 msg = self.session.msg('execute',content=dict(a=10))
60 msg_list = self.session.serialize(msg, ident=b'foo')
61 ident, msg_list = self.session.feed_identities(msg_list)
62 new_msg = self.session.unserialize(msg_list)
63 self.assertEquals(ident[0], b'foo')
64 self.assertEquals(new_msg['header'],msg['header'])
65 self.assertEquals(new_msg['content'],msg['content'])
66 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
67
68 def test_send(self):
69 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
70
71 msg = self.session.msg('execute', content=dict(a=10))
72 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
73 ident, msg_list = self.session.feed_identities(socket.data)
74 new_msg = self.session.unserialize(msg_list)
75 self.assertEquals(ident[0], b'foo')
76 self.assertEquals(new_msg['header'],msg['header'])
77 self.assertEquals(new_msg['content'],msg['content'])
78 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
79 self.assertEquals(new_msg['buffers'],[b'bar'])
80
81 socket.data = []
82
83 content = msg['content']
84 header = msg['header']
85 parent = msg['parent_header']
86 msg_type = header['msg_type']
87 self.session.send(socket, None, content=content, parent=parent,
88 header=header, ident=b'foo', buffers=[b'bar'])
89 ident, msg_list = self.session.feed_identities(socket.data)
90 new_msg = self.session.unserialize(msg_list)
91 self.assertEquals(ident[0], b'foo')
92 self.assertEquals(new_msg['header'],msg['header'])
93 self.assertEquals(new_msg['content'],msg['content'])
94 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
95 self.assertEquals(new_msg['buffers'],[b'bar'])
96
97 socket.data = []
98
99 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
100 ident, new_msg = self.session.recv(socket)
101 self.assertEquals(ident[0], b'foo')
102 self.assertEquals(new_msg['header'],msg['header'])
103 self.assertEquals(new_msg['content'],msg['content'])
104 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
105 self.assertEquals(new_msg['buffers'],[b'bar'])
106
107 socket.close()
108
44 109 def test_args(self):
45 110 """initialization arguments for Session"""
46 111 s = self.session
47 112 self.assertTrue(s.pack is ss.default_packer)
48 113 self.assertTrue(s.unpack is ss.default_unpacker)
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
114 self.assertEquals(s.username, os.environ.get('USER', u'username'))
50 115
51 116 s = ss.Session()
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
117 self.assertEquals(s.username, os.environ.get('USER', u'username'))
53 118
54 119 self.assertRaises(TypeError, ss.Session, pack='hi')
55 120 self.assertRaises(TypeError, ss.Session, unpack='hi')
56 121 u = str(uuid.uuid4())
57 s = ss.Session(username='carrot', session=u)
122 s = ss.Session(username=u'carrot', session=u)
58 123 self.assertEquals(s.session, u)
59 self.assertEquals(s.username, 'carrot')
124 self.assertEquals(s.username, u'carrot')
60 125
61 126 def test_tracking(self):
62 127 """test tracking messages"""
63 128 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
64 129 s = self.session
65 130 stream = ZMQStream(a)
66 131 msg = s.send(a, 'hello', track=False)
67 132 self.assertTrue(msg['tracker'] is None)
68 133 msg = s.send(a, 'hello', track=True)
69 134 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
70 135 M = zmq.Message(b'hi there', track=True)
71 136 msg = s.send(a, 'hello', buffers=[M], track=True)
72 137 t = msg['tracker']
73 138 self.assertTrue(isinstance(t, zmq.MessageTracker))
74 139 self.assertRaises(zmq.NotDone, t.wait, .1)
75 140 del M
76 141 t.wait(1) # this will raise
77 142
78 143
79 144 # def test_rekey(self):
80 145 # """rekeying dict around json str keys"""
81 146 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
82 147 # self.assertRaises(KeyError, ss.rekey, d)
83 148 #
84 149 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
85 150 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
86 151 # rd = ss.rekey(d)
87 152 # self.assertEquals(d2,rd)
88 153 #
89 154 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
90 155 # d2 = {1.5:d['1.5'],1:d['1']}
91 156 # rd = ss.rekey(d)
92 157 # self.assertEquals(d2,rd)
93 158 #
94 159 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
95 160 # self.assertRaises(KeyError, ss.rekey, d)
96 161 #
97 162 def test_unique_msg_ids(self):
98 163 """test that messages receive unique ids"""
99 164 ids = set()
100 165 for i in range(2**12):
101 166 h = self.session.msg_header('test')
102 167 msg_id = h['msg_id']
103 168 self.assertTrue(msg_id not in ids)
104 169 ids.add(msg_id)
105 170
106 171 def test_feed_identities(self):
107 172 """scrub the front for zmq IDENTITIES"""
108 173 theids = "engine client other".split()
109 174 content = dict(code='whoda',stuff=object())
110 175 themsg = self.session.msg('execute',content=content)
111 176 pmsg = theids
177
@@ -1,937 +1,937 b''
1 1 .. _messaging:
2 2
3 3 ======================
4 4 Messaging in IPython
5 5 ======================
6 6
7 7
8 8 Introduction
9 9 ============
10 10
11 11 This document explains the basic communications design and messaging
12 12 specification for how the various IPython objects interact over a network
13 13 transport. The current implementation uses the ZeroMQ_ library for messaging
14 14 within and between hosts.
15 15
16 16 .. Note::
17 17
18 18 This document should be considered the authoritative description of the
19 19 IPython messaging protocol, and all developers are strongly encouraged to
20 20 keep it updated as the implementation evolves, so that we have a single
21 21 common reference for all protocol details.
22 22
23 23 The basic design is explained in the following diagram:
24 24
25 25 .. image:: figs/frontend-kernel.png
26 26 :width: 450px
27 27 :alt: IPython kernel/frontend messaging architecture.
28 28 :align: center
29 29 :target: ../_images/frontend-kernel.png
30 30
31 31 A single kernel can be simultaneously connected to one or more frontends. The
32 32 kernel has three sockets that serve the following functions:
33 33
34 34 1. REQ: this socket is connected to a *single* frontend at a time, and it allows
35 35 the kernel to request input from a frontend when :func:`raw_input` is called.
36 36 The frontend holding the matching REP socket acts as a 'virtual keyboard'
37 37 for the kernel while this communication is happening (illustrated in the
38 38 figure by the black outline around the central keyboard). In practice,
39 39 frontends may display such kernel requests using a special input widget or
40 40 otherwise indicating that the user is to type input for the kernel instead
41 41 of normal commands in the frontend.
42 42
43 43 2. XREP: this single socket allows multiple incoming connections from
44 44 frontends, and this is the socket where requests for code execution, object
45 45 information, prompts, etc. are made to the kernel by any frontend. The
46 46 communication on this socket is a sequence of request/reply actions from
47 47 each frontend and the kernel.
48 48
49 49 3. PUB: this socket is the 'broadcast channel' where the kernel publishes all
50 50 side effects (stdout, stderr, etc.) as well as the requests coming from any
51 51 client over the XREP socket and its own requests on the REP socket. There
52 52 are a number of actions in Python which generate side effects: :func:`print`
53 53 writes to ``sys.stdout``, errors generate tracebacks, etc. Additionally, in
54 54 a multi-client scenario, we want all frontends to be able to know what each
55 55 other has sent to the kernel (this can be useful in collaborative scenarios,
56 56 for example). This socket allows both side effects and the information
57 57 about communications taking place with one client over the XREQ/XREP channel
58 58 to be made available to all clients in a uniform manner.
59 59
60 60 All messages are tagged with enough information (details below) for clients
61 61 to know which messages come from their own interaction with the kernel and
62 62 which ones are from other clients, so they can display each type
63 63 appropriately.
64 64
65 65 The actual format of the messages allowed on each of these channels is
66 66 specified below. Messages are dicts of dicts with string keys and values that
67 67 are reasonably representable in JSON. Our current implementation uses JSON
68 68 explicitly as its message format, but this shouldn't be considered a permanent
69 69 feature. As we've discovered that JSON has non-trivial performance issues due
70 70 to excessive copying, we may in the future move to a pure pickle-based raw
71 71 message format. However, it should be possible to easily convert from the raw
72 72 objects to JSON, since we may have non-python clients (e.g. a web frontend).
73 73 As long as it's easy to make a JSON version of the objects that is a faithful
74 74 representation of all the data, we can communicate with such clients.
75 75
76 76 .. Note::
77 77
78 78 Not all of these have yet been fully fleshed out, but the key ones are, see
79 79 kernel and frontend files for actual implementation details.
80 80
81 81
82 82 Python functional API
83 83 =====================
84 84
85 85 As messages are dicts, they map naturally to a ``func(**kw)`` call form. We
86 86 should develop, at a few key points, functional forms of all the requests that
87 87 take arguments in this manner and automatically construct the necessary dict
88 88 for sending.
89 89
90 90
91 91 General Message Format
92 92 ======================
93 93
94 94 All messages send or received by any IPython process should have the following
95 95 generic structure::
96 96
97 97 {
98 98 # The message header contains a pair of unique identifiers for the
99 99 # originating session and the actual message id, in addition to the
100 100 # username for the process that generated the message. This is useful in
101 101 # collaborative settings where multiple users may be interacting with the
102 102 # same kernel simultaneously, so that frontends can label the various
103 103 # messages in a meaningful way.
104 'header' : { 'msg_id' : uuid,
105 'username' : str,
106 'session' : uuid
104 'header' : {
105 'msg_id' : uuid,
106 'username' : str,
107 'session' : uuid
108 # All recognized message type strings are listed below.
109 'msg_type' : str,
107 110 },
108 111
109 112 # In a chain of messages, the header from the parent is copied so that
110 113 # clients can track where messages come from.
111 114 'parent_header' : dict,
112 115
113 # All recognized message type strings are listed below.
114 'msg_type' : str,
115
116 116 # The actual content of the message must be a dict, whose structure
117 117 # depends on the message type.
118 118 'content' : dict,
119 119 }
120 120
121 121 For each message type, the actual content will differ and all existing message
122 122 types are specified in what follows of this document.
123 123
124 124
125 125 Messages on the XREP/XREQ socket
126 126 ================================
127 127
128 128 .. _execute:
129 129
130 130 Execute
131 131 -------
132 132
133 133 This message type is used by frontends to ask the kernel to execute code on
134 134 behalf of the user, in a namespace reserved to the user's variables (and thus
135 135 separate from the kernel's own internal code and variables).
136 136
137 137 Message type: ``execute_request``::
138 138
139 139 content = {
140 140 # Source code to be executed by the kernel, one or more lines.
141 141 'code' : str,
142 142
143 143 # A boolean flag which, if True, signals the kernel to execute this
144 144 # code as quietly as possible. This means that the kernel will compile
145 145 # the code with 'exec' instead of 'single' (so
146 146 # sys.displayhook will not fire), and will *not*:
147 147 # - broadcast exceptions on the PUB socket
148 148 # - do any logging
149 149 # - populate any history
150 150 #
151 151 # The default is False.
152 152 'silent' : bool,
153 153
154 154 # A list of variable names from the user's namespace to be retrieved. What
155 155 # is returned is a JSON string of each variable's repr(), not a python object.
156 156 'user_variables' : list,
157 157
158 158 # Similarly, a dict mapping names to expressions to be evaluated in the
159 159 # user's dict.
160 160 'user_expressions' : dict,
161 161 }
162 162
163 163 The ``code`` field contains a single string (possibly multiline). The kernel
164 164 is responsible for splitting this into one or more independent execution blocks
165 165 and deciding whether to compile these in 'single' or 'exec' mode (see below for
166 166 detailed execution semantics).
167 167
168 168 The ``user_`` fields deserve a detailed explanation. In the past, IPython had
169 169 the notion of a prompt string that allowed arbitrary code to be evaluated, and
170 170 this was put to good use by many in creating prompts that displayed system
171 171 status, path information, and even more esoteric uses like remote instrument
172 172 status acquired over the network. But now that IPython has a clean separation
173 173 between the kernel and the clients, the kernel has no prompt knowledge; prompts
174 174 are a frontend-side feature, and it should be even possible for different
175 175 frontends to display different prompts while interacting with the same kernel.
176 176
177 177 The kernel now provides the ability to retrieve data from the user's namespace
178 178 after the execution of the main ``code``, thanks to two fields in the
179 179 ``execute_request`` message:
180 180
181 181 - ``user_variables``: If only variables from the user's namespace are needed, a
182 182 list of variable names can be passed and a dict with these names as keys and
183 183 their :func:`repr()` as values will be returned.
184 184
185 185 - ``user_expressions``: For more complex expressions that require function
186 186 evaluations, a dict can be provided with string keys and arbitrary python
187 187 expressions as values. The return message will also contain a dict with the
188 188 same keys and the :func:`repr()` of the evaluated expressions as value.
189 189
190 190 With this information, frontends can display any status information they wish
191 191 in the form that best suits each frontend (a status line, a popup, inline for a
192 192 terminal, etc).
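
As a hedged illustration (variable and expression names invented), a request
that runs some code and then retrieves values from the namespace afterwards
might carry::

    content = {
        'code' : 'x = 2**10',
        'silent' : False,
        'user_variables' : ['x'],
        'user_expressions' : {'x_squared' : 'x**2'},
    }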
193 193
194 194 .. Note::
195 195
196 196 In order to obtain the current execution counter for the purposes of
197 197 displaying input prompts, frontends simply make an execution request with an
198 198 empty code string and ``silent=True``.
199 199
200 200 Execution semantics
201 201 ~~~~~~~~~~~~~~~~~~~
202 202
203 203 When the silent flag is false, the execution of user code consists of the
204 204 following phases (in silent mode, only the ``code`` field is executed):
205 205
206 206 1. Run the ``pre_runcode_hook``.
207 207
208 208 2. Execute the ``code`` field, see below for details.
209 209
210 210 3. If #2 succeeds, ``user_variables`` and ``user_expressions`` are
211 211    computed. This ensures that any error in the latter doesn't harm the main
212 212    code execution.
213 213
214 214 4. Call any method registered with :meth:`register_post_execute`.
215 215
216 216 .. warning::
217 217
218 218 The API for running code before/after the main code block is likely to
219 219 change soon. Both the ``pre_runcode_hook`` and the
220 220 :meth:`register_post_execute` are susceptible to modification, as we find a
221 221 consistent model for both.
222 222
223 223 To understand how the ``code`` field is executed, one must know that Python
224 224 code can be compiled in one of three modes (controlled by the ``mode`` argument
225 225 to the :func:`compile` builtin):
226 226
227 227 *single*
228 228 Valid for a single interactive statement (though the source can contain
229 229 multiple lines, such as a for loop). When compiled in this mode, the
230 230 generated bytecode contains special instructions that trigger the calling of
231 231 :func:`sys.displayhook` for any expression in the block that returns a value.
232 232 This means that a single statement can actually produce multiple calls to
233 233 :func:`sys.displayhook`, if for example it contains a loop where each
234 234 iteration computes an unassigned expression; the following would generate 10 calls::
235 235
236 236 for i in range(10):
237 237 i**2
238 238
239 239 *exec*
240 240 An arbitrary amount of source code, this is how modules are compiled.
241 241 :func:`sys.displayhook` is *never* implicitly called.
242 242
243 243 *eval*
244 244 A single expression that returns a value. :func:`sys.displayhook` is *never*
245 245 implicitly called.
246 246
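The difference can be checked directly with the :func:`compile` builtin; this
standalone sketch shows that only 'single' mode triggers
:func:`sys.displayhook` (here, the default hook, which prints each value)::

    code = "for i in range(3):\n    i**2\n"

    exec(compile(code, '<test>', 'exec'))    # prints nothing
    exec(compile(code, '<test>', 'single'))  # prints 0, 1, 4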
247 247
248 248 The ``code`` field is split into individual blocks each of which is valid for
249 249 execution in 'single' mode, and then:
250 250
251 251 - If there is only a single block: it is executed in 'single' mode.
252 252
253 253 - If there is more than one block:
254 254
255 255 * if the last one is no more than two lines long, run all but the last in
256 256   'exec' mode and the very last one in 'single' mode. This makes it easy to
257 257   type simple expressions at the end to see computed values.
263 263
264 264 * otherwise (last one is also multiline), run all in 'exec' mode as a single
265 265 unit.
266 266
267 267 Any error in retrieving the ``user_variables`` or evaluating the
268 268 ``user_expressions`` will result in a simple error message in the return fields
269 269 of the form::
270 270
271 271 [ERROR] ExceptionType: Exception message
272 272
273 273 The user can simply send the same variable name or expression for evaluation to
274 274 see a regular traceback.
275 275
276 276 Errors in any registered post_execute functions are also reported similarly,
277 277 and the failing function is removed from the post_execution set so that it does
278 278 not continue triggering failures.
279 279
280 280 Upon completion of the execution request, the kernel *always* sends a reply,
281 281 with a status code indicating what happened and additional data depending on
282 282 the outcome. See :ref:`below <execution_results>` for the possible return
283 283 codes and associated data.
284 284
285 285
286 286 Execution counter (old prompt number)
287 287 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
288 288
289 289 The kernel has a single, monotonically increasing counter of all execution
290 290 requests that are made with ``silent=False``. This counter is used to populate
291 291 the ``In[n]``, ``Out[n]`` and ``_n`` variables, so clients will likely want to
292 292 display it in some form to the user, which will typically (but not necessarily)
293 293 be done in the prompts. The value of this counter will be returned as the
294 294 ``execution_count`` field of all ``execute_reply`` messages.
295 295
296 296 .. _execution_results:
297 297
298 298 Execution results
299 299 ~~~~~~~~~~~~~~~~~
300 300
301 301 Message type: ``execute_reply``::
302 302
303 303 content = {
304 304 # One of: 'ok' OR 'error' OR 'abort'
305 305 'status' : str,
306 306
307 307 # The global kernel counter that increases by one with each non-silent
308 308 # executed request. This will typically be used by clients to display
309 309 # prompt numbers to the user. If the request was a silent one, this will
310 310 # be the current value of the counter in the kernel.
311 311 'execution_count' : int,
312 312 }
313 313
314 314 When status is 'ok', the following extra fields are present::
315 315
316 316 {
317 317 # The execution payload is a dict with string keys that may have been
318 318 # produced by the code being executed. It is retrieved by the kernel at
319 319 # the end of the execution and sent back to the front end, which can take
320 320 # action on it as needed. See main text for further details.
321 321 'payload' : dict,
322 322
323 323 # Results for the user_variables and user_expressions.
324 324 'user_variables' : dict,
325 325 'user_expressions' : dict,
326 326
327 327 # The kernel will often transform the input provided to it. If the
328 328 # '---->' transform had been applied, this is filled, otherwise it's the
329 329 # empty string. So transformations like magics don't appear here, only
330 330 # autocall ones.
331 331 'transformed_code' : str,
332 332 }
333 333
334 334 .. admonition:: Execution payloads
335 335
336 336 The notion of an 'execution payload' is different from a return value of a
337 337 given set of code, which normally is just displayed on the pyout stream
338 338 through the PUB socket. The idea of a payload is to allow special types of
339 339 code, typically magics, to populate a data container in the IPython kernel
340 340 that will be shipped back to the caller via this channel. The kernel will
341 341 have an API for this, probably something along the lines of::
342 342
343 343 ip.exec_payload_add(key, value)
344 344
345 345 though this API is still in the design stages. The data returned in this
346 346 payload will allow frontends to present special views of what just happened.
347 347
348 348
349 349 When status is 'error', the following extra fields are present::
350 350
351 351 {
352 352 'exc_name' : str, # Exception name, as a string
353 353 'exc_value' : str, # Exception value, as a string
354 354
355 355 # The traceback will contain a list of frames, represented each as a
356 356 # string. For now we'll stick to the existing design of ultraTB, which
357 357 # controls exception level of detail statefully. But eventually we'll
358 358 # want to grow into a model where more information is collected and
359 359 # packed into the traceback object, with clients deciding how little or
360 360 # how much of it to unpack. But for now, let's start with a simple list
361 361 # of strings, since that requires only minimal changes to ultratb as
362 362 # written.
363 363 'traceback' : list,
364 364 }
365 365
366 366
367 367 When status is 'abort', there are for now no additional data fields. This
368 368 happens when the kernel was interrupted by a signal.
369 369
370 370 Kernel attribute access
371 371 -----------------------
372 372
373 373 .. warning::
374 374
375 375 This part of the messaging spec is not actually implemented in the kernel
376 376 yet.
377 377
378 378 While this protocol does not specify full RPC access to arbitrary methods of
379 379 the kernel object, the kernel does allow read (and in some cases write) access
380 380 to certain attributes.
381 381
382 382 The policy for which attributes can be read is: any attribute of the kernel, or
383 383 its sub-objects, that belongs to a :class:`Configurable` object and has been
384 384 declared at the class-level with Traits validation, is in principle accessible
385 385 as long as its name does not begin with a leading underscore. The attribute
386 386 itself will have metadata indicating whether it allows remote read and/or write
387 387 access. The message spec follows for attribute read and write requests.
388 388
389 389 Message type: ``getattr_request``::
390 390
391 391 content = {
392 392 # The (possibly dotted) name of the attribute
393 393 'name' : str,
394 394 }
395 395
396 396 When a ``getattr_request`` fails, there are two possible error types:
397 397
398 398 - AttributeError: this type of error was raised when trying to access the
399 399 given name by the kernel itself. This means that the attribute likely
400 400 doesn't exist.
401 401
402 402 - AccessError: the attribute exists but its value is not readable remotely.
403 403
404 404
405 405 Message type: ``getattr_reply``::
406 406
407 407 content = {
408 408 # One of ['ok', 'AttributeError', 'AccessError'].
409 409 'status' : str,
410 410 # If status is 'ok', a JSON object.
411 411 'value' : object,
412 412 }
413 413
414 414 Message type: ``setattr_request``::
415 415
416 416 content = {
417 417 # The (possibly dotted) name of the attribute
418 418 'name' : str,
419 419
420 420 # A JSON-encoded object, that will be validated by the Traits
421 421 # information in the kernel
422 422 'value' : object,
423 423 }
424 424
425 425 When a ``setattr_request`` fails, there are also two possible error types with
426 426 similar meanings as those of the ``getattr_request`` case, but for writing.
427 427
428 428 Message type: ``setattr_reply``::
429 429
430 430 content = {
431 431 # One of ['ok', 'AttributeError', 'AccessError'].
432 432 'status' : str,
433 433 }
434 434
435 435
436 436
437 437 Object information
438 438 ------------------
439 439
440 440 One of IPython's most used capabilities is the introspection of Python objects
441 441 in the user's namespace, typically invoked via the ``?`` and ``??`` characters
442 442 (which in reality are shorthands for the ``%pinfo`` magic). This is used often
443 443 enough that it warrants an explicit message type, especially because frontends
444 444 may want to get object information in response to user keystrokes (like Tab or
445 445 F1) besides the user explicitly typing code like ``x??``.
446 446
447 447 Message type: ``object_info_request``::
448 448
449 449 content = {
450 450 # The (possibly dotted) name of the object to be searched in all
451 451 # relevant namespaces
452 452 'name' : str,
453 453
454 454 # The level of detail desired. The default (0) is equivalent to typing
455 455 # 'x?' at the prompt, 1 is equivalent to 'x??'.
456 456 'detail_level' : int,
457 457 }
458 458
459 459 The returned information will be a dictionary with keys very similar to the
460 460 field names that IPython prints at the terminal.
461 461
462 462 Message type: ``object_info_reply``::
463 463
464 464 content = {
465 465 # The name the object was requested under
466 466 'name' : str,
467 467
468 468 # Boolean flag indicating whether the named object was found or not. If
469 469 # it's false, all other fields will be empty.
470 470 'found' : bool,
471 471
472 472 # Flags for magics and system aliases
473 473 'ismagic' : bool,
474 474 'isalias' : bool,
475 475
476 476 # The name of the namespace where the object was found ('builtin',
477 477 # 'magics', 'alias', 'interactive', etc.)
478 478 'namespace' : str,
479 479
480 480 # The type name will be type.__name__ for normal Python objects, but it
481 481 # can also be a string like 'Magic function' or 'System alias'
482 482 'type_name' : str,
483 483
484 484 # The string form of the object, possibly truncated for length if
485 485 # detail_level is 0
486 486 'string_form' : str,
487 487
488 488 # For objects with a __class__ attribute this will be set
489 489 'base_class' : str,
490 490
491 491 # For objects with a __len__ attribute this will be set
492 492 'length' : int,
493 493
494 494 # If the object is a function, class or method whose file we can find,
495 495 # we give its full path
496 496 'file' : str,
497 497
498 498 # For pure Python callable objects, we can reconstruct the object
499 499 # definition line which provides its call signature. For convenience this
500 500 # is returned as a single 'definition' field, but below the raw parts that
501 501 # compose it are also returned as the argspec field.
502 502 'definition' : str,
503 503
504 504 # The individual parts that together form the definition string. Clients
505 505 # with rich display capabilities may use this to provide a richer and more
506 506 # precise representation of the definition line (e.g. by highlighting
507 507 # arguments based on the user's cursor position). For non-callable
508 508 # objects, this field is empty.
509 509 'argspec' : { # The names of all the arguments
510 510 args : list,
511 511 # The name of the varargs (*args), if any
512 512 varargs : str,
513 513 # The name of the varkw (**kw), if any
514 514 varkw : str,
515 515 # The values (as strings) of all default arguments. Note
516 516 # that these must be matched *in reverse* with the 'args'
517 517 # list above, since the first positional args have no default
518 518 # value at all.
519 519 defaults : list,
520 520 },
521 521
522 522 # For instances, provide the constructor signature (the definition of
523 523 # the __init__ method):
524 524 'init_definition' : str,
525 525
526 526 # Docstrings: for any object (function, method, module, package) with a
527 527 # docstring, we show it. But in addition, we may provide additional
528 528 # docstrings. For example, for instances we will show the constructor
529 529 # and class docstrings as well, if available.
530 530 'docstring' : str,
531 531
532 532 # For instances, provide the constructor and class docstrings
533 533 'init_docstring' : str,
534 534 'class_docstring' : str,
535 535
536 536 # If it's a callable object whose call method has a separate docstring and
537 537 # definition line:
538 538 'call_def' : str,
539 539 'call_docstring' : str,
540 540
541 541 # If detail_level was 1, we also try to find the source code that
542 542 # defines the object, if possible. The string 'None' will indicate
543 543 # that no source was found.
544 544 'source' : str,
545 545 }
546 546
547 547
548 548 Complete
549 549 --------
550 550
551 551 Message type: ``complete_request``::
552 552
553 553 content = {
554 554 # The text to be completed, such as 'a.is'
555 555 'text' : str,
556 556
557 557 # The full line, such as 'print a.is'. This allows completers to
558 558 # make decisions that may require information about more than just the
559 559 # current word.
560 560 'line' : str,
561 561
562 562 # The entire block of text where the line is. This may be useful in the
563 563 # case of multiline completions where more context may be needed. Note: if
564 564 # in practice this field proves unnecessary, remove it to lighten the
565 565 # messages.
566 566
567 567 'block' : str,
568 568
569 569 # The position of the cursor where the user hit 'TAB' on the line.
570 570 'cursor_pos' : int,
571 571 }
572 572
573 573 Message type: ``complete_reply``::
574 574
575 575 content = {
576 576 # The list of all matches to the completion request, such as
577 577 # ['a.isalnum', 'a.isalpha'] for the above example.
578 578 'matches' : list
579 579 }
580 580
581 581
582 582 History
583 583 -------
584 584
585 585 For clients to explicitly request history from a kernel. The kernel has all
586 586 the actual execution history stored in a single location, so clients can
587 587 request it from the kernel when needed.
588 588
589 589 Message type: ``history_request``::
590 590
591 591 content = {
592 592
593 593 # If True, also return output history in the resulting dict.
594 594 'output' : bool,
595 595
596 596 # If True, return the raw input history, else the transformed input.
597 597 'raw' : bool,
598 598
599 599 # So far, this can be 'range', 'tail' or 'search'.
600 600 'hist_access_type' : str,
601 601
602 602 # If hist_access_type is 'range', get a range of input cells. session can
603 603 # be a positive session number, or a negative number to count back from
604 604 # the current session.
605 605 'session' : int,
606 606 # start and stop are line numbers within that session.
607 607 'start' : int,
608 608 'stop' : int,
609 609
610 610 # If hist_access_type is 'tail', get the last n cells.
611 611 'n' : int,
612 612
613 613 # If hist_access_type is 'search', get cells matching the specified glob
614 614 # pattern (with * and ? as wildcards).
615 615 'pattern' : str,
616 616
617 617 }
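
For instance (values invented for illustration), a client asking for the last
10 cells of input and output history might send::

    content = {
        'output' : True,
        'raw' : False,
        'hist_access_type' : 'tail',
        'n' : 10,
    }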
618 618
619 619 Message type: ``history_reply``::
620 620
621 621 content = {
622 622 # A list of 3 tuples, either:
623 623 # (session, line_number, input) or
624 624 # (session, line_number, (input, output)),
625 625 # depending on whether output was False or True, respectively.
626 626 'history' : list,
627 627 }
628 628
629 629
630 630 Connect
631 631 -------
632 632
633 633 When a client connects to the request/reply socket of the kernel, it can issue
634 634 a connect request to get basic information about the kernel, such as the ports
635 635 the other ZeroMQ sockets are listening on. This allows clients to only have
636 636 to know about a single port (the XREQ/XREP channel) to connect to a kernel.
637 637
638 638 Message type: ``connect_request``::
639 639
640 640 content = {
641 641 }
642 642
643 643 Message type: ``connect_reply``::
644 644
645 645 content = {
646 646 'xrep_port' : int # The port the XREP socket is listening on.
647 647 'pub_port' : int # The port the PUB socket is listening on.
648 648 'req_port' : int # The port the REQ socket is listening on.
649 649 'hb_port' : int # The port the heartbeat socket is listening on.
650 650 }
651 651
652 652
653 653
654 654 Kernel shutdown
655 655 ---------------
656 656
657 657 The clients can request the kernel to shut itself down; this is used in
658 658 multiple cases:
659 659
660 660 - when the user chooses to close the client application via a menu or window
661 661 control.
662 662 - when the user types 'exit' or 'quit' (or their uppercase magic equivalents).
663 663 - when the user chooses a GUI method (like the 'Ctrl-C' shortcut in the
664 664 IPythonQt client) to force a kernel restart to get a clean kernel without
665 665 losing client-side state like history or inlined figures.
666 666
667 667 The client sends a shutdown request to the kernel, and once it receives the
668 668 reply message (which is otherwise empty), it can assume that the kernel has
669 669 completed shutdown safely.
670 670
671 671 Upon their own shutdown, client applications will typically execute a last
672 672 minute sanity check and forcefully terminate any kernel that is still alive, to
673 673 avoid leaving stray processes on the user's machine.
674 674
675 675 For both shutdown request and reply, there is no actual content that needs to
676 676 be sent, so the content dict is empty.
677 677
678 678 Message type: ``shutdown_request``::
679 679
680 680 content = {
681 681 'restart' : bool # whether the shutdown is final, or precedes a restart
682 682 }
683 683
684 684 Message type: ``shutdown_reply``::
685 685
686 686 content = {
687 687 'restart' : bool # whether the shutdown is final, or precedes a restart
688 688 }
689 689
690 690 .. Note::
691 691
692 692 When the clients detect a dead kernel thanks to inactivity on the heartbeat
693 693 socket, they simply send a forceful process termination signal, since a dead
694 694 process is unlikely to respond in any useful way to messages.
695 695
696 696
697 697 Messages on the PUB/SUB socket
698 698 ==============================
699 699
700 700 Streams (stdout, stderr, etc)
701 701 ------------------------------
702 702
703 703 Message type: ``stream``::
704 704
705 705 content = {
706 706 # The name of the stream is one of 'stdin', 'stdout', 'stderr'
707 707 'name' : str,
708 708
709 709 # The data is an arbitrary string to be written to that stream
710 710 'data' : str,
711 711 }
712 712
713 713 When a kernel receives a raw_input call, it should also broadcast it on the pub
714 714 socket with the names 'stdin' and 'stdin_reply'. This will allow other clients
715 715 to monitor/display kernel interactions and possibly replay them to their user
716 716 or otherwise expose them.
717 717
718 718 Display Data
719 719 ------------
720 720
721 721 This type of message is used to bring back data that should be displayed (text,
722 722 html, svg, etc.) in the frontends. This data is published to all frontends.
723 723 Each message can have multiple representations of the data; it is up to the
724 724 frontend to decide which to use and how. A single message should contain all
725 725 possible representations of the same information. Each representation should
726 726 be a JSON'able data structure, and should be a valid MIME type.
727 727
728 728 Some questions remain about this design:
729 729
730 730 * Do we use this message type for pyout/displayhook? Probably not, because
731 731 the displayhook also has to handle the Out prompt display. On the other hand
732 732 we could put that information into the metadata section.
733 733
734 734 Message type: ``display_data``::
735 735
736 736 content = {
737 737
738 738 # Who created the data
739 739 'source' : str,
740 740
741 741 # The data dict contains key/value pairs, where the keys are MIME
742 742 # types and the values are the raw data of the representation in that
743 743 # format. The data dict must minimally contain the ``text/plain``
744 744 # MIME type which is used as a backup representation.
745 745 'data' : dict,
746 746
747 747 # Any metadata that describes the data
748 748 'metadata' : dict
749 749 }
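
An illustrative (not normative) example carrying two representations of the
same datum, with an invented source name::

    content = {
        'source' : 'IPython.core.displaypub',   # hypothetical
        'data' : {
            'text/plain' : '<Figure: 1 axes>',
            'image/svg+xml' : '<svg height="100" width="100"></svg>',
        },
        'metadata' : {},
    }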
750 750
751 751 Python inputs
752 752 -------------
753 753
754 754 These messages are the re-broadcast of the ``execute_request``.
755 755
756 756 Message type: ``pyin``::
757 757
758 758 content = {
759 759 'code' : str # Source code to be executed, one or more lines
760 760 }
761 761
762 762 Python outputs
763 763 --------------
764 764
765 765 When Python produces output from code that has been compiled in with the
766 766 'single' flag to :func:`compile`, any expression that produces a value (such as
767 767 ``1+1``) is passed to ``sys.displayhook``, which is a callable that can do with
768 768 this value whatever it wants. The default behavior of ``sys.displayhook`` in
769 769 the Python interactive prompt is to print to ``sys.stdout`` the :func:`repr` of
770 770 the value as long as it is not ``None`` (which isn't printed at all). In our
771 771 case, the kernel instantiates as ``sys.displayhook`` an object which has
772 772 similar behavior, but which instead of printing to stdout, broadcasts these
773 773 values as ``pyout`` messages for clients to display appropriately.
774 774
775 775 IPython's displayhook can handle multiple simultaneous formats depending on its
776 776 configuration. The default pretty-printed repr text is always given with the
777 777 ``data`` entry in this message. Any other formats are provided in the
778 778 ``extra_formats`` list. Frontends are free to display any or all of these
779 779 according to their capabilities. The ``extra_formats`` list contains 3-tuples of
780 780 an ID string, a type string, and the data. The ID is unique to the formatter
781 781 implementation that created the data. Frontends will typically ignore the ID
782 782 unless they have requested a particular formatter. The type string tells the
783 783 frontend how to interpret the data. It is often, but not always, a MIME type.
784 784 Frontends should ignore types that they do not understand. The data itself is
785 785 any JSON object and depends on the format. It is often, but not always, a string.
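
For example (ID and values invented), a formatter producing LaTeX might
contribute an entry like::

    extra_formats = [
        ('my.latex.formatter', 'text/latex', '$x^{2}$'),
    ]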
786 786
787 787 Message type: ``pyout``::
788 788
789 789 content = {
790 790
791 791 # The counter for this execution is also provided so that clients can
792 792 # display it, since IPython automatically creates variables called _N
793 793 # (for prompt N).
794 794 'execution_count' : int,
795 795
796 796 # The data dict contains key/value pairs, where the keys are MIME
797 797 # types and the values are the raw data of the representation in that
798 798 # format. The data dict must minimally contain the ``text/plain``
799 799 # MIME type which is used as a backup representation.
800 800 'data' : dict,
801 801
802 802 }
803 803
804 804 Python errors
805 805 -------------
806 806
807 807 When an error occurs during code execution, the kernel publishes information
about the error so that all frontends can display it.
808 808
809 809 Message type: ``pyerr``::
810 810
811 811 content = {
812 812 # Similar content to the execute_reply messages for the 'error' case,
813 813 # except the 'status' field is omitted.
814 814 }
815 815
816 816 Kernel status
817 817 -------------
818 818
819 819 This message type is used by frontends to monitor the status of the kernel.
820 820
821 821 Message type: ``status``::
822 822
823 823 content = {
824 824 # When the kernel starts to execute code, it will enter the 'busy'
825 825 # state and when it finishes, it will enter the 'idle' state.
826 826 execution_state : ('busy', 'idle')
827 827 }
828 828
829 829 Kernel crashes
830 830 --------------
831 831
832 832 When the kernel has an unexpected exception, caught by the last-resort
833 833 sys.excepthook, we should broadcast the crash handler's output before exiting.
834 834 This will allow clients to notice that a kernel died, inform the user and
835 835 propose further actions.
836 836
837 837 Message type: ``crash``::
838 838
839 839 content = {
840 840 # Similarly to the 'error' case for execute_reply messages, this will
841 841 # contain exc_name, exc_value and traceback fields.
842 842
843 843 # An additional field with supplementary information such as where to
844 844 # send the crash message
845 845 'info' : str,
846 846 }
847 847
848 848
849 849 Future ideas
850 850 ------------
851 851
852 852 Other potential message types, currently unimplemented, listed below as ideas.
853 853
854 854 Message type: ``file``::
855 855
856 856 content = {
857 857 'path' : 'cool.jpg',
858 858 'mimetype' : str,
859 859 'data' : str,
860 860 }
861 861
862 862
863 863 Messages on the REQ/REP socket
864 864 ==============================
865 865
866 866 This is a socket that goes in the opposite direction: from the kernel to a
867 867 *single* frontend, and its purpose is to allow ``raw_input`` and similar
868 868 operations that read from ``sys.stdin`` on the kernel to be fulfilled by the
869 869 client. For now we will keep these messages as simple as possible, since they
870 870 basically only need to convey the ``raw_input(prompt)`` call.
871 871
872 872 Message type: ``input_request``::
873 873
874 874 content = { 'prompt' : str }
875 875
876 876 Message type: ``input_reply``::
877 877
878 878 content = { 'value' : str }
879 879
880 880 .. Note::
881 881
882 882 We do not explicitly try to forward the raw ``sys.stdin`` object, because in
883 883 practice the kernel should behave like an interactive program. When a
884 884 program is opened on the console, the keyboard effectively takes over the
885 885 ``stdin`` file descriptor, and it can't be used for raw reading anymore.
886 886 Since the IPython kernel effectively behaves like a console program (albeit
887 887 one whose "keyboard" is actually living in a separate process and
888 888 transported over the zmq connection), raw ``stdin`` isn't expected to be
889 889 available.
890 890
891 891
892 892 Heartbeat for kernels
893 893 =====================
894 894
895 895 Initially we had considered using messages like those above over ZMQ for a
896 896 kernel 'heartbeat' (a way to detect quickly and reliably whether a kernel is
897 897 alive at all, even if it may be busy executing user code). But this has the
898 898 problem that if the kernel is locked inside extension code, it wouldn't execute
899 899 the python heartbeat code. But it turns out that we can implement a basic
900 900 heartbeat with pure ZMQ, without using any Python messaging at all.
901 901
902 902 The monitor sends out a single zmq message (right now, it is a str of the
903 903 monitor's lifetime in seconds), and gets the same message right back, prefixed
904 904 with the zmq identity of the XREQ socket in the heartbeat process. This can be
905 905 a uuid, or even a full message, but there doesn't seem to be a need for packing
906 906 up a message when the sender and receiver are the exact same Python object.
907 907
908 908 The model is this::
909 909
910 910 monitor.send(str(self.lifetime)) # '1.2345678910'
911 911
912 912 and the monitor receives some number of messages of the form::
913 913
914 914 ['uuid-abcd-dead-beef', '1.2345678910']
915 915
916 916 where the first part is the zmq.IDENTITY of the heart's XREQ on the engine, and
917 917 the rest is the message sent by the monitor. No Python code ever has any
918 918 access to the message between the monitor's send, and the monitor's recv.
919 919
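A minimal sketch of this scheme with pyzmq, using a plain REQ/REP echo pair in
place of the real XREQ/XREP sockets (the socket types, port, and timeout here
are simplifications for illustration, not the actual implementation)::

    import time
    import zmq

    ctx = zmq.Context()

    # The 'heart' simply echoes back whatever it receives.
    heart = ctx.socket(zmq.REP)
    heart.bind('tcp://127.0.0.1:5555')

    # The monitor pings with its lifetime and waits for the echo.
    monitor = ctx.socket(zmq.REQ)
    monitor.connect('tcp://127.0.0.1:5555')

    start = time.time()
    monitor.send(str(time.time() - start))
    heart.send(heart.recv())              # echo one ping back

    poller = zmq.Poller()
    poller.register(monitor, zmq.POLLIN)
    if poller.poll(1000):                 # wait up to 1s for the pong
        print('heartbeat ok: %s' % monitor.recv())
    else:
        print('no heartbeat -- kernel presumed dead')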
920 920
921 921 ToDo
922 922 ====
923 923
924 924 Missing things include:
925 925
926 926 * Important: finish thinking through the payload concept and API.
927 927
928 928 * Important: ensure that we have a good solution for magics like %edit. It's
929 929 likely that with the payload concept we can build a full solution, but not
930 930 100% clear yet.
931 931
932 932 * Finishing the details of the heartbeat protocol.
933 933
934 934 * Signal handling: specify what kind of information kernel should broadcast (or
935 935 not) when it receives signals.
936 936
937 937 .. include:: ../links.rst