##// END OF EJS Templates
Collapse ZMQSocketChannel into HBChannel class
Thomas Kluyver -
Show More
@@ -1,338 +1,259 b''
1 """Base classes to manage a Client's interaction with a running kernel"""
1 """Base classes to manage a Client's interaction with a running kernel"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import absolute_import
6 from __future__ import absolute_import
7
7
8 import atexit
8 import atexit
9 import errno
9 import errno
10 from threading import Thread
10 from threading import Thread
11 import time
11 import time
12
12
13 import zmq
13 import zmq
14 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
14 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
15 # during garbage collection of threads at exit:
15 # during garbage collection of threads at exit:
16 from zmq import ZMQError
16 from zmq import ZMQError
17 from zmq.eventloop import ioloop, zmqstream
17 from zmq.eventloop import ioloop, zmqstream
18
18
19 from IPython.core.release import kernel_protocol_version_info
19 from IPython.core.release import kernel_protocol_version_info
20
20
21 from .channelsabc import (
21 from .channelsabc import (
22 ShellChannelABC, IOPubChannelABC,
22 ShellChannelABC, IOPubChannelABC,
23 HBChannelABC, StdInChannelABC,
23 HBChannelABC, StdInChannelABC,
24 )
24 )
25 from IPython.utils.py3compat import string_types, iteritems
25 from IPython.utils.py3compat import string_types, iteritems
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Constants and exceptions
28 # Constants and exceptions
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 major_protocol_version = kernel_protocol_version_info[0]
31 major_protocol_version = kernel_protocol_version_info[0]
32
32
33 class InvalidPortNumber(Exception):
33 class InvalidPortNumber(Exception):
34 pass
34 pass
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Utility functions
37 # Utility functions
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40 # some utilities to validate message structure, these might get moved elsewhere
40 # some utilities to validate message structure, these might get moved elsewhere
41 # if they prove to have more generic utility
41 # if they prove to have more generic utility
42
42
43 def validate_string_list(lst):
43 def validate_string_list(lst):
44 """Validate that the input is a list of strings.
44 """Validate that the input is a list of strings.
45
45
46 Raises ValueError if not."""
46 Raises ValueError if not."""
47 if not isinstance(lst, list):
47 if not isinstance(lst, list):
48 raise ValueError('input %r must be a list' % lst)
48 raise ValueError('input %r must be a list' % lst)
49 for x in lst:
49 for x in lst:
50 if not isinstance(x, string_types):
50 if not isinstance(x, string_types):
51 raise ValueError('element %r in list must be a string' % x)
51 raise ValueError('element %r in list must be a string' % x)
52
52
53
53
54 def validate_string_dict(dct):
54 def validate_string_dict(dct):
55 """Validate that the input is a dict with string keys and values.
55 """Validate that the input is a dict with string keys and values.
56
56
57 Raises ValueError if not."""
57 Raises ValueError if not."""
58 for k,v in iteritems(dct):
58 for k,v in iteritems(dct):
59 if not isinstance(k, string_types):
59 if not isinstance(k, string_types):
60 raise ValueError('key %r in dict must be a string' % k)
60 raise ValueError('key %r in dict must be a string' % k)
61 if not isinstance(v, string_types):
61 if not isinstance(v, string_types):
62 raise ValueError('value %r in dict must be a string' % v)
62 raise ValueError('value %r in dict must be a string' % v)
63
63
64
64
65 #-----------------------------------------------------------------------------
66 # ZMQ Socket Channel classes
67 #-----------------------------------------------------------------------------
68
69 class ZMQSocketChannel(Thread):
70 """The base class for the channels that use ZMQ sockets."""
71 context = None
72 session = None
73 socket = None
74 ioloop = None
75 stream = None
76 _address = None
77 _exiting = False
78 proxy_methods = []
79
80 def __init__(self, context, session, address):
81 """Create a channel.
82
83 Parameters
84 ----------
85 context : :class:`zmq.Context`
86 The ZMQ context to use.
87 session : :class:`session.Session`
88 The session to use.
89 address : zmq url
90 Standard (ip, port) tuple that the kernel is listening on.
91 """
92 super(ZMQSocketChannel, self).__init__()
93 self.daemon = True
94
95 self.context = context
96 self.session = session
97 if isinstance(address, tuple):
98 if address[1] == 0:
99 message = 'The port number for a channel cannot be 0.'
100 raise InvalidPortNumber(message)
101 address = "tcp://%s:%i" % address
102 self._address = address
103 atexit.register(self._notice_exit)
104
105 def _notice_exit(self):
106 self._exiting = True
107
108 def _run_loop(self):
109 """Run my loop, ignoring EINTR events in the poller"""
110 while True:
111 try:
112 self.ioloop.start()
113 except ZMQError as e:
114 if e.errno == errno.EINTR:
115 continue
116 else:
117 raise
118 except Exception:
119 if self._exiting:
120 break
121 else:
122 raise
123 else:
124 break
125
126 def stop(self):
127 """Stop the channel's event loop and join its thread.
128
129 This calls :meth:`~threading.Thread.join` and returns when the thread
130 terminates. :class:`RuntimeError` will be raised if
131 :meth:`~threading.Thread.start` is called again.
132 """
133 if self.ioloop is not None:
134 self.ioloop.stop()
135 self.join()
136 self.close()
137
138 def close(self):
139 if self.ioloop is not None:
140 try:
141 self.ioloop.close(all_fds=True)
142 except Exception:
143 pass
144 if self.socket is not None:
145 try:
146 self.socket.close(linger=0)
147 except Exception:
148 pass
149 self.socket = None
150
151 @property
152 def address(self):
153 """Get the channel's address as a zmq url string.
154
155 These URLS have the form: 'tcp://127.0.0.1:5555'.
156 """
157 return self._address
158
159 def _queue_send(self, msg):
160 """Queue a message to be sent from the IOLoop's thread.
161
162 Parameters
163 ----------
164 msg : message to send
165
166 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
167 thread control of the action.
168 """
169 def thread_send():
170 self.session.send(self.stream, msg)
171 self.ioloop.add_callback(thread_send)
172
173 def _handle_recv(self, msg):
174 """Callback for stream.on_recv.
175
176 Unpacks message, and calls handlers with it.
177 """
178 ident,smsg = self.session.feed_identities(msg)
179 msg = self.session.deserialize(smsg)
180 self.call_handlers(msg)
181
182
183 def make_shell_socket(context, identity, address):
65 def make_shell_socket(context, identity, address):
184 socket = context.socket(zmq.DEALER)
66 socket = context.socket(zmq.DEALER)
185 socket.linger = 1000
67 socket.linger = 1000
186 socket.setsockopt(zmq.IDENTITY, identity)
68 socket.setsockopt(zmq.IDENTITY, identity)
187 socket.connect(address)
69 socket.connect(address)
188 return socket
70 return socket
189
71
190 def make_iopub_socket(context, identity, address):
72 def make_iopub_socket(context, identity, address):
191 socket = context.socket(zmq.SUB)
73 socket = context.socket(zmq.SUB)
192 socket.linger = 1000
74 socket.linger = 1000
193 socket.setsockopt(zmq.SUBSCRIBE,b'')
75 socket.setsockopt(zmq.SUBSCRIBE,b'')
194 socket.setsockopt(zmq.IDENTITY, identity)
76 socket.setsockopt(zmq.IDENTITY, identity)
195 socket.connect(address)
77 socket.connect(address)
196 return socket
78 return socket
197
79
198 def make_stdin_socket(context, identity, address):
80 def make_stdin_socket(context, identity, address):
199 socket = context.socket(zmq.DEALER)
81 socket = context.socket(zmq.DEALER)
200 socket.linger = 1000
82 socket.linger = 1000
201 socket.setsockopt(zmq.IDENTITY, identity)
83 socket.setsockopt(zmq.IDENTITY, identity)
202 socket.connect(address)
84 socket.connect(address)
203 return socket
85 return socket
204
86
205 class HBChannel(ZMQSocketChannel):
87 class HBChannel(Thread):
206 """The heartbeat channel which monitors the kernel heartbeat.
88 """The heartbeat channel which monitors the kernel heartbeat.
207
89
208 Note that the heartbeat channel is paused by default. As long as you start
90 Note that the heartbeat channel is paused by default. As long as you start
209 this channel, the kernel manager will ensure that it is paused and un-paused
91 this channel, the kernel manager will ensure that it is paused and un-paused
210 as appropriate.
92 as appropriate.
211 """
93 """
94 context = None
95 session = None
96 socket = None
97 address = None
98 _exiting = False
212
99
213 time_to_dead = 1.
100 time_to_dead = 1.
214 socket = None
215 poller = None
101 poller = None
216 _running = None
102 _running = None
217 _pause = None
103 _pause = None
218 _beating = None
104 _beating = None
219
105
220 def __init__(self, context, session, address):
106 def __init__(self, context, session, address):
221 super(HBChannel, self).__init__(context, session, address)
107 """Create the heartbeat monitor thread.
108
109 Parameters
110 ----------
111 context : :class:`zmq.Context`
112 The ZMQ context to use.
113 session : :class:`session.Session`
114 The session to use.
115 address : zmq url
116 Standard (ip, port) tuple that the kernel is listening on.
117 """
118 super(HBChannel, self).__init__()
119 self.daemon = True
120
121 self.context = context
122 self.session = session
123 if isinstance(address, tuple):
124 if address[1] == 0:
125 message = 'The port number for a channel cannot be 0.'
126 raise InvalidPortNumber(message)
127 address = "tcp://%s:%i" % address
128 self.address = address
129 atexit.register(self._notice_exit)
130
222 self._running = False
131 self._running = False
223 self._pause =True
132 self._pause = True
224 self.poller = zmq.Poller()
133 self.poller = zmq.Poller()
225
134
135 def _notice_exit(self):
136 self._exiting = True
137
226 def _create_socket(self):
138 def _create_socket(self):
227 if self.socket is not None:
139 if self.socket is not None:
228 # close previous socket, before opening a new one
140 # close previous socket, before opening a new one
229 self.poller.unregister(self.socket)
141 self.poller.unregister(self.socket)
230 self.socket.close()
142 self.socket.close()
231 self.socket = self.context.socket(zmq.REQ)
143 self.socket = self.context.socket(zmq.REQ)
232 self.socket.linger = 1000
144 self.socket.linger = 1000
233 self.socket.connect(self.address)
145 self.socket.connect(self.address)
234
146
235 self.poller.register(self.socket, zmq.POLLIN)
147 self.poller.register(self.socket, zmq.POLLIN)
236
148
237 def _poll(self, start_time):
149 def _poll(self, start_time):
238 """poll for heartbeat replies until we reach self.time_to_dead.
150 """poll for heartbeat replies until we reach self.time_to_dead.
239
151
240 Ignores interrupts, and returns the result of poll(), which
152 Ignores interrupts, and returns the result of poll(), which
241 will be an empty list if no messages arrived before the timeout,
153 will be an empty list if no messages arrived before the timeout,
242 or the event tuple if there is a message to receive.
154 or the event tuple if there is a message to receive.
243 """
155 """
244
156
245 until_dead = self.time_to_dead - (time.time() - start_time)
157 until_dead = self.time_to_dead - (time.time() - start_time)
246 # ensure poll at least once
158 # ensure poll at least once
247 until_dead = max(until_dead, 1e-3)
159 until_dead = max(until_dead, 1e-3)
248 events = []
160 events = []
249 while True:
161 while True:
250 try:
162 try:
251 events = self.poller.poll(1000 * until_dead)
163 events = self.poller.poll(1000 * until_dead)
252 except ZMQError as e:
164 except ZMQError as e:
253 if e.errno == errno.EINTR:
165 if e.errno == errno.EINTR:
254 # ignore interrupts during heartbeat
166 # ignore interrupts during heartbeat
255 # this may never actually happen
167 # this may never actually happen
256 until_dead = self.time_to_dead - (time.time() - start_time)
168 until_dead = self.time_to_dead - (time.time() - start_time)
257 until_dead = max(until_dead, 1e-3)
169 until_dead = max(until_dead, 1e-3)
258 pass
170 pass
259 else:
171 else:
260 raise
172 raise
261 except Exception:
173 except Exception:
262 if self._exiting:
174 if self._exiting:
263 break
175 break
264 else:
176 else:
265 raise
177 raise
266 else:
178 else:
267 break
179 break
268 return events
180 return events
269
181
270 def run(self):
182 def run(self):
271 """The thread's main activity. Call start() instead."""
183 """The thread's main activity. Call start() instead."""
272 self._create_socket()
184 self._create_socket()
273 self._running = True
185 self._running = True
274 self._beating = True
186 self._beating = True
275
187
276 while self._running:
188 while self._running:
277 if self._pause:
189 if self._pause:
278 # just sleep, and skip the rest of the loop
190 # just sleep, and skip the rest of the loop
279 time.sleep(self.time_to_dead)
191 time.sleep(self.time_to_dead)
280 continue
192 continue
281
193
282 since_last_heartbeat = 0.0
194 since_last_heartbeat = 0.0
283 # io.rprint('Ping from HB channel') # dbg
195 # io.rprint('Ping from HB channel') # dbg
284 # no need to catch EFSM here, because the previous event was
196 # no need to catch EFSM here, because the previous event was
285 # either a recv or connect, which cannot be followed by EFSM
197 # either a recv or connect, which cannot be followed by EFSM
286 self.socket.send(b'ping')
198 self.socket.send(b'ping')
287 request_time = time.time()
199 request_time = time.time()
288 ready = self._poll(request_time)
200 ready = self._poll(request_time)
289 if ready:
201 if ready:
290 self._beating = True
202 self._beating = True
291 # the poll above guarantees we have something to recv
203 # the poll above guarantees we have something to recv
292 self.socket.recv()
204 self.socket.recv()
293 # sleep the remainder of the cycle
205 # sleep the remainder of the cycle
294 remainder = self.time_to_dead - (time.time() - request_time)
206 remainder = self.time_to_dead - (time.time() - request_time)
295 if remainder > 0:
207 if remainder > 0:
296 time.sleep(remainder)
208 time.sleep(remainder)
297 continue
209 continue
298 else:
210 else:
299 # nothing was received within the time limit, signal heart failure
211 # nothing was received within the time limit, signal heart failure
300 self._beating = False
212 self._beating = False
301 since_last_heartbeat = time.time() - request_time
213 since_last_heartbeat = time.time() - request_time
302 self.call_handlers(since_last_heartbeat)
214 self.call_handlers(since_last_heartbeat)
303 # and close/reopen the socket, because the REQ/REP cycle has been broken
215 # and close/reopen the socket, because the REQ/REP cycle has been broken
304 self._create_socket()
216 self._create_socket()
305 continue
217 continue
306
218
307 def pause(self):
219 def pause(self):
308 """Pause the heartbeat."""
220 """Pause the heartbeat."""
309 self._pause = True
221 self._pause = True
310
222
311 def unpause(self):
223 def unpause(self):
312 """Unpause the heartbeat."""
224 """Unpause the heartbeat."""
313 self._pause = False
225 self._pause = False
314
226
315 def is_beating(self):
227 def is_beating(self):
316 """Is the heartbeat running and responsive (and not paused)."""
228 """Is the heartbeat running and responsive (and not paused)."""
317 if self.is_alive() and not self._pause and self._beating:
229 if self.is_alive() and not self._pause and self._beating:
318 return True
230 return True
319 else:
231 else:
320 return False
232 return False
321
233
322 def stop(self):
234 def stop(self):
323 """Stop the channel's event loop and join its thread."""
235 """Stop the channel's event loop and join its thread."""
324 self._running = False
236 self._running = False
325 super(HBChannel, self).stop()
237 self.join()
238 self.close()
239
240 def close(self):
241 if self.socket is not None:
242 try:
243 self.socket.close(linger=0)
244 except Exception:
245 pass
246 self.socket = None
326
247
327 def call_handlers(self, since_last_heartbeat):
248 def call_handlers(self, since_last_heartbeat):
328 """This method is called in the ioloop thread when a message arrives.
249 """This method is called in the ioloop thread when a message arrives.
329
250
330 Subclasses should override this method to handle incoming messages.
251 Subclasses should override this method to handle incoming messages.
331 It is important to remember that this method is called in the thread
252 It is important to remember that this method is called in the thread
332 so that some logic must be done to ensure that the application level
253 so that some logic must be done to ensure that the application level
333 handlers are called in the application thread.
254 handlers are called in the application thread.
334 """
255 """
335 pass
256 pass
336
257
337
258
338 HBChannelABC.register(HBChannel)
259 HBChannelABC.register(HBChannel)
General Comments 0
You need to be logged in to leave comments. Login now