##// END OF EJS Templates
Collapse ZMQSocketChannel into HBChannel class
Thomas Kluyver -
Show More
@@ -1,338 +1,259 b''
1 1 """Base classes to manage a Client's interaction with a running kernel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import absolute_import
7 7
8 8 import atexit
9 9 import errno
10 10 from threading import Thread
11 11 import time
12 12
13 13 import zmq
14 14 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
15 15 # during garbage collection of threads at exit:
16 16 from zmq import ZMQError
17 17 from zmq.eventloop import ioloop, zmqstream
18 18
19 19 from IPython.core.release import kernel_protocol_version_info
20 20
21 21 from .channelsabc import (
22 22 ShellChannelABC, IOPubChannelABC,
23 23 HBChannelABC, StdInChannelABC,
24 24 )
25 25 from IPython.utils.py3compat import string_types, iteritems
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Constants and exceptions
29 29 #-----------------------------------------------------------------------------
30 30
31 31 major_protocol_version = kernel_protocol_version_info[0]
32 32
33 33 class InvalidPortNumber(Exception):
34 34 pass
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Utility functions
38 38 #-----------------------------------------------------------------------------
39 39
40 40 # some utilities to validate message structure, these might get moved elsewhere
41 41 # if they prove to have more generic utility
42 42
43 43 def validate_string_list(lst):
44 44 """Validate that the input is a list of strings.
45 45
46 46 Raises ValueError if not."""
47 47 if not isinstance(lst, list):
48 48 raise ValueError('input %r must be a list' % lst)
49 49 for x in lst:
50 50 if not isinstance(x, string_types):
51 51 raise ValueError('element %r in list must be a string' % x)
52 52
53 53
54 54 def validate_string_dict(dct):
55 55 """Validate that the input is a dict with string keys and values.
56 56
57 57 Raises ValueError if not."""
58 58 for k,v in iteritems(dct):
59 59 if not isinstance(k, string_types):
60 60 raise ValueError('key %r in dict must be a string' % k)
61 61 if not isinstance(v, string_types):
62 62 raise ValueError('value %r in dict must be a string' % v)
63 63
64 64
65 #-----------------------------------------------------------------------------
66 # ZMQ Socket Channel classes
67 #-----------------------------------------------------------------------------
68
69 class ZMQSocketChannel(Thread):
70 """The base class for the channels that use ZMQ sockets."""
71 context = None
72 session = None
73 socket = None
74 ioloop = None
75 stream = None
76 _address = None
77 _exiting = False
78 proxy_methods = []
79
80 def __init__(self, context, session, address):
81 """Create a channel.
82
83 Parameters
84 ----------
85 context : :class:`zmq.Context`
86 The ZMQ context to use.
87 session : :class:`session.Session`
88 The session to use.
89 address : zmq url
90 Standard (ip, port) tuple that the kernel is listening on.
91 """
92 super(ZMQSocketChannel, self).__init__()
93 self.daemon = True
94
95 self.context = context
96 self.session = session
97 if isinstance(address, tuple):
98 if address[1] == 0:
99 message = 'The port number for a channel cannot be 0.'
100 raise InvalidPortNumber(message)
101 address = "tcp://%s:%i" % address
102 self._address = address
103 atexit.register(self._notice_exit)
104
105 def _notice_exit(self):
106 self._exiting = True
107
108 def _run_loop(self):
109 """Run my loop, ignoring EINTR events in the poller"""
110 while True:
111 try:
112 self.ioloop.start()
113 except ZMQError as e:
114 if e.errno == errno.EINTR:
115 continue
116 else:
117 raise
118 except Exception:
119 if self._exiting:
120 break
121 else:
122 raise
123 else:
124 break
125
126 def stop(self):
127 """Stop the channel's event loop and join its thread.
128
129 This calls :meth:`~threading.Thread.join` and returns when the thread
130 terminates. :class:`RuntimeError` will be raised if
131 :meth:`~threading.Thread.start` is called again.
132 """
133 if self.ioloop is not None:
134 self.ioloop.stop()
135 self.join()
136 self.close()
137
138 def close(self):
139 if self.ioloop is not None:
140 try:
141 self.ioloop.close(all_fds=True)
142 except Exception:
143 pass
144 if self.socket is not None:
145 try:
146 self.socket.close(linger=0)
147 except Exception:
148 pass
149 self.socket = None
150
151 @property
152 def address(self):
153 """Get the channel's address as a zmq url string.
154
155 These URLS have the form: 'tcp://127.0.0.1:5555'.
156 """
157 return self._address
158
159 def _queue_send(self, msg):
160 """Queue a message to be sent from the IOLoop's thread.
161
162 Parameters
163 ----------
164 msg : message to send
165
166 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
167 thread control of the action.
168 """
169 def thread_send():
170 self.session.send(self.stream, msg)
171 self.ioloop.add_callback(thread_send)
172
173 def _handle_recv(self, msg):
174 """Callback for stream.on_recv.
175
176 Unpacks message, and calls handlers with it.
177 """
178 ident,smsg = self.session.feed_identities(msg)
179 msg = self.session.deserialize(smsg)
180 self.call_handlers(msg)
181
182
183 65 def make_shell_socket(context, identity, address):
184 66 socket = context.socket(zmq.DEALER)
185 67 socket.linger = 1000
186 68 socket.setsockopt(zmq.IDENTITY, identity)
187 69 socket.connect(address)
188 70 return socket
189 71
190 72 def make_iopub_socket(context, identity, address):
191 73 socket = context.socket(zmq.SUB)
192 74 socket.linger = 1000
193 75 socket.setsockopt(zmq.SUBSCRIBE,b'')
194 76 socket.setsockopt(zmq.IDENTITY, identity)
195 77 socket.connect(address)
196 78 return socket
197 79
198 80 def make_stdin_socket(context, identity, address):
199 81 socket = context.socket(zmq.DEALER)
200 82 socket.linger = 1000
201 83 socket.setsockopt(zmq.IDENTITY, identity)
202 84 socket.connect(address)
203 85 return socket
204 86
205 class HBChannel(ZMQSocketChannel):
87 class HBChannel(Thread):
206 88 """The heartbeat channel which monitors the kernel heartbeat.
207 89
208 90 Note that the heartbeat channel is paused by default. As long as you start
209 91 this channel, the kernel manager will ensure that it is paused and un-paused
210 92 as appropriate.
211 93 """
94 context = None
95 session = None
96 socket = None
97 address = None
98 _exiting = False
212 99
213 100 time_to_dead = 1.
214 socket = None
215 101 poller = None
216 102 _running = None
217 103 _pause = None
218 104 _beating = None
219 105
220 106 def __init__(self, context, session, address):
221 super(HBChannel, self).__init__(context, session, address)
107 """Create the heartbeat monitor thread.
108
109 Parameters
110 ----------
111 context : :class:`zmq.Context`
112 The ZMQ context to use.
113 session : :class:`session.Session`
114 The session to use.
115 address : zmq url
116 Standard (ip, port) tuple that the kernel is listening on.
117 """
118 super(HBChannel, self).__init__()
119 self.daemon = True
120
121 self.context = context
122 self.session = session
123 if isinstance(address, tuple):
124 if address[1] == 0:
125 message = 'The port number for a channel cannot be 0.'
126 raise InvalidPortNumber(message)
127 address = "tcp://%s:%i" % address
128 self.address = address
129 atexit.register(self._notice_exit)
130
222 131 self._running = False
223 self._pause =True
132 self._pause = True
224 133 self.poller = zmq.Poller()
225 134
135 def _notice_exit(self):
136 self._exiting = True
137
226 138 def _create_socket(self):
227 139 if self.socket is not None:
228 140 # close previous socket, before opening a new one
229 141 self.poller.unregister(self.socket)
230 142 self.socket.close()
231 143 self.socket = self.context.socket(zmq.REQ)
232 144 self.socket.linger = 1000
233 145 self.socket.connect(self.address)
234 146
235 147 self.poller.register(self.socket, zmq.POLLIN)
236 148
237 149 def _poll(self, start_time):
238 150 """poll for heartbeat replies until we reach self.time_to_dead.
239 151
240 152 Ignores interrupts, and returns the result of poll(), which
241 153 will be an empty list if no messages arrived before the timeout,
242 154 or the event tuple if there is a message to receive.
243 155 """
244 156
245 157 until_dead = self.time_to_dead - (time.time() - start_time)
246 158 # ensure poll at least once
247 159 until_dead = max(until_dead, 1e-3)
248 160 events = []
249 161 while True:
250 162 try:
251 163 events = self.poller.poll(1000 * until_dead)
252 164 except ZMQError as e:
253 165 if e.errno == errno.EINTR:
254 166 # ignore interrupts during heartbeat
255 167 # this may never actually happen
256 168 until_dead = self.time_to_dead - (time.time() - start_time)
257 169 until_dead = max(until_dead, 1e-3)
258 170 pass
259 171 else:
260 172 raise
261 173 except Exception:
262 174 if self._exiting:
263 175 break
264 176 else:
265 177 raise
266 178 else:
267 179 break
268 180 return events
269 181
270 182 def run(self):
271 183 """The thread's main activity. Call start() instead."""
272 184 self._create_socket()
273 185 self._running = True
274 186 self._beating = True
275 187
276 188 while self._running:
277 189 if self._pause:
278 190 # just sleep, and skip the rest of the loop
279 191 time.sleep(self.time_to_dead)
280 192 continue
281 193
282 194 since_last_heartbeat = 0.0
283 195 # io.rprint('Ping from HB channel') # dbg
284 196 # no need to catch EFSM here, because the previous event was
285 197 # either a recv or connect, which cannot be followed by EFSM
286 198 self.socket.send(b'ping')
287 199 request_time = time.time()
288 200 ready = self._poll(request_time)
289 201 if ready:
290 202 self._beating = True
291 203 # the poll above guarantees we have something to recv
292 204 self.socket.recv()
293 205 # sleep the remainder of the cycle
294 206 remainder = self.time_to_dead - (time.time() - request_time)
295 207 if remainder > 0:
296 208 time.sleep(remainder)
297 209 continue
298 210 else:
299 211 # nothing was received within the time limit, signal heart failure
300 212 self._beating = False
301 213 since_last_heartbeat = time.time() - request_time
302 214 self.call_handlers(since_last_heartbeat)
303 215 # and close/reopen the socket, because the REQ/REP cycle has been broken
304 216 self._create_socket()
305 217 continue
306 218
307 219 def pause(self):
308 220 """Pause the heartbeat."""
309 221 self._pause = True
310 222
311 223 def unpause(self):
312 224 """Unpause the heartbeat."""
313 225 self._pause = False
314 226
315 227 def is_beating(self):
316 228 """Is the heartbeat running and responsive (and not paused)."""
317 229 if self.is_alive() and not self._pause and self._beating:
318 230 return True
319 231 else:
320 232 return False
321 233
322 234 def stop(self):
323 235 """Stop the channel's event loop and join its thread."""
324 236 self._running = False
325 super(HBChannel, self).stop()
237 self.join()
238 self.close()
239
240 def close(self):
241 if self.socket is not None:
242 try:
243 self.socket.close(linger=0)
244 except Exception:
245 pass
246 self.socket = None
326 247
327 248 def call_handlers(self, since_last_heartbeat):
328 249 """This method is called in the ioloop thread when a message arrives.
329 250
330 251 Subclasses should override this method to handle incoming messages.
331 252 It is important to remember that this method is called in the thread
332 253 so that some logic must be done to ensure that the application level
333 254 handlers are called in the application thread.
334 255 """
335 256 pass
336 257
337 258
338 259 HBChannelABC.register(HBChannel)
General Comments 0
You need to be logged in to leave comments. Login now