Merge pull request #6110 from minrk/binarycomm...
Matthias Bussonnier
r18418:2a8a2c87 merge
@@ -0,0 +1,114 @@
1 // Copyright (c) IPython Development Team.
2 // Distributed under the terms of the Modified BSD License.
3
4 define([
5 'underscore',
6 ], function (_) {
7 "use strict";
8
9 var _deserialize_array_buffer = function (buf) {
10 var data = new DataView(buf);
11 // read the header: 1 + nbufs 32b integers
12 var nbufs = data.getUint32(0);
13 var offsets = [];
14 var i;
15 for (i = 1; i <= nbufs; i++) {
16 offsets.push(data.getUint32(i * 4));
17 }
18 var json_bytes = new Uint8Array(buf.slice(offsets[0], offsets[1]));
19 var msg = JSON.parse(
20 (new TextDecoder('utf8')).decode(json_bytes)
21 );
22 // the remaining chunks are stored as DataViews in msg.buffers
23 msg.buffers = [];
24 var start, stop;
25 for (i = 1; i < nbufs; i++) {
26 start = offsets[i];
27 stop = offsets[i+1] || buf.byteLength;
28 msg.buffers.push(new DataView(buf.slice(start, stop)));
29 }
30 return msg;
31 };
32
33 var _deserialize_binary = function(data, callback) {
34 // deserialize the binary message format
35 // callback will be called with a message whose buffers attribute
36 // will be an array of DataViews.
37 if (data instanceof Blob) {
38 // data is Blob, have to deserialize from ArrayBuffer in reader callback
39 var reader = new FileReader();
40 reader.onload = function () {
41 var msg = _deserialize_array_buffer(this.result);
42 callback(msg);
43 };
44 reader.readAsArrayBuffer(data);
45 } else {
46 // data is ArrayBuffer, can deserialize directly
47 var msg = _deserialize_array_buffer(data);
48 callback(msg);
49 }
50 };
51
52 var deserialize = function (data, callback) {
53 // deserialize a message and pass the unpacked message object to callback
54 if (typeof data === "string") {
55 // text JSON message
56 callback(JSON.parse(data));
57 } else {
58 // binary message
59 _deserialize_binary(data, callback);
60 }
61 };
62
63 var _serialize_binary = function (msg) {
64 // implement the binary serialization protocol
65 // serializes JSON message to ArrayBuffer
66 msg = _.clone(msg);
67 var offsets = [];
68 var buffers = [];
69 msg.buffers.map(function (buf) {
70 buffers.push(buf);
71 });
72 delete msg.buffers;
73 var json_utf8 = (new TextEncoder('utf8')).encode(JSON.stringify(msg));
74 buffers.unshift(json_utf8);
75 var nbufs = buffers.length;
76 offsets.push(4 * (nbufs + 1));
77 var i;
78 for (i = 0; i + 1 < buffers.length; i++) {
79 offsets.push(offsets[offsets.length-1] + buffers[i].byteLength);
80 }
81 var msg_buf = new Uint8Array(
82 offsets[offsets.length-1] + buffers[buffers.length-1].byteLength
83 );
84 // use DataView.setUint32 for network byte-order
85 var view = new DataView(msg_buf.buffer);
86 // write nbufs to first 4 bytes
87 view.setUint32(0, nbufs);
88 // write offsets to next 4 * nbufs bytes
89 for (i = 0; i < offsets.length; i++) {
90 view.setUint32(4 * (i+1), offsets[i]);
91 }
92 // write all the buffers at their respective offsets
93 for (i = 0; i < buffers.length; i++) {
94 msg_buf.set(new Uint8Array(buffers[i].buffer), offsets[i]);
95 }
96
97 // return raw ArrayBuffer
98 return msg_buf.buffer;
99 };
100
101 var serialize = function (msg) {
102 if (msg.buffers && msg.buffers.length) {
103 return _serialize_binary(msg);
104 } else {
105 return JSON.stringify(msg);
106 }
107 };
108
109 var exports = {
110 deserialize : deserialize,
111 serialize: serialize
112 };
113 return exports;
114 });
\ No newline at end of file
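The module above is symmetric: `serialize` returns a plain JSON string when the message has no buffers and an ArrayBuffer framed with the nbufs/offsets header otherwise, while `deserialize` accepts either form. A minimal round-trip sketch, assuming the object returned by the `define()` call is in scope as `serialize` (the message fields here are illustrative):

    // Round-trip through the serialize module (sketch, not part of the patch).
    var msg = {
        header: {msg_id: 'abc', msg_type: 'comm_msg'},
        parent_header: {},
        metadata: {},
        content: {data: 'hi'},
        buffers: [new DataView(new Uint8Array([255, 0]).buffer)]
    };
    var wire = serialize.serialize(msg);          // ArrayBuffer, since buffers is non-empty
    serialize.deserialize(wire, function (msg2) { // callback API covers the async Blob path
        console.log(msg2.content.data);           // 'hi'
        console.log(msg2.buffers[0].getUint8(0)); // 255
    });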
@@ -0,0 +1,113 @@
1 //
2 // Test binary messages on websockets.
3 // Only works on slimer for now, due to old websocket impl in phantomjs.
4 //
5
6 casper.notebook_test(function () {
7 if (!this.slimerjs) {
8 console.log("Can't test binary websockets on phantomjs.");
9 return;
10 }
11 // create EchoBuffers target on js-side.
12 // it just captures and echoes comm messages.
13 this.then(function () {
14 var success = this.evaluate(function () {
15 IPython._msgs = [];
16
17 var EchoBuffers = function(comm) {
18 this.comm = comm;
19 this.comm.on_msg($.proxy(this.on_msg, this));
20 };
21
22 EchoBuffers.prototype.on_msg = function (msg) {
23 IPython._msgs.push(msg);
24 this.comm.send(msg.content.data, {}, {}, msg.buffers);
25 };
26
27 IPython.notebook.kernel.comm_manager.register_target("echo", function (comm) {
28 return new EchoBuffers(comm);
29 });
30
31 return true;
32 });
33 this.test.assertEquals(success, true, "Created echo comm target");
34 });
35
36 // Create a similar comm that captures messages Python-side
37 this.then(function () {
38 var index = this.append_cell([
39 "import os",
40 "from IPython.kernel.comm import Comm",
41 "comm = Comm(target_name='echo')",
42 "msgs = []",
43 "def on_msg(msg):",
44 " msgs.append(msg)",
45 "comm.on_msg(on_msg)"
46 ].join('\n'), 'code');
47 this.execute_cell(index);
48 });
49
50 // send a message with binary data
51 this.then(function () {
52 var index = this.append_cell([
53 "buffers = [b'\\xFF\\x00', b'\\x00\\x01\\x02']",
54 "comm.send(data='hi', buffers=buffers)"
55 ].join('\n'), 'code');
56 this.execute_cell(index);
57 });
58
59 // wait for capture
60 this.waitFor(function () {
61 return this.evaluate(function () {
62 return IPython._msgs.length > 0;
63 });
64 });
65
66 // validate captured buffers js-side
67 this.then(function () {
68 var msgs = this.evaluate(function () {
69 return IPython._msgs;
70 });
71 this.test.assertEquals(msgs.length, 1, "Captured comm message");
72 var buffers = msgs[0].buffers;
73 this.test.assertEquals(buffers.length, 2, "comm message has buffers");
74
75 // extract attributes to test in evaluate,
76 // because the raw DataViews can't be passed across
77 var buf_info = function (index) {
78 var buf = IPython._msgs[0].buffers[index];
79 var data = {};
80 data.byteLength = buf.byteLength;
81 data.bytes = [];
82 for (var i = 0; i < data.byteLength; i++) {
83 data.bytes.push(buf.getUint8(i));
84 }
85 return data;
86 };
87
88 var buf0 = this.evaluate(buf_info, 0);
89 var buf1 = this.evaluate(buf_info, 1);
90 this.test.assertEquals(buf0.byteLength, 2, 'buf[0] has correct size');
91 this.test.assertEquals(buf0.bytes, [255, 0], 'buf[0] has correct bytes');
92 this.test.assertEquals(buf1.byteLength, 3, 'buf[1] has correct size');
93 this.test.assertEquals(buf1.bytes, [0, 1, 2], 'buf[1] has correct bytes');
94 });
95
96 // validate captured buffers Python-side
97 this.then(function () {
98 var index = this.append_cell([
99 "assert len(msgs) == 1, len(msgs)",
100 "bufs = msgs[0]['buffers']",
101 "assert len(bufs) == len(buffers), bufs",
102 "assert bufs[0].bytes == buffers[0], bufs[0].bytes",
103 "assert bufs[1].bytes == buffers[1], bufs[1].bytes",
104 "1",
105 ].join('\n'), 'code');
106 this.execute_cell(index);
107 this.wait_for_output(index);
108 this.then(function () {
109 var out = this.get_output_cell(index);
110 this.test.assertEquals(out['text/plain'], '1', "Python received buffers");
111 });
112 });
113 });
@@ -0,0 +1,26 @@
1 """Test serialize/deserialize messages with buffers"""
2
3 import os
4
5 import nose.tools as nt
6
7 from IPython.kernel.zmq.session import Session
8 from ..base.zmqhandlers import (
9 serialize_binary_message,
10 deserialize_binary_message,
11 )
12
13 def test_serialize_binary():
14 s = Session()
15 msg = s.msg('data_pub', content={'a': 'b'})
16 msg['buffers'] = [ os.urandom(3) for i in range(3) ]
17 bmsg = serialize_binary_message(msg)
18 nt.assert_is_instance(bmsg, bytes)
19
20 def test_deserialize_binary():
21 s = Session()
22 msg = s.msg('data_pub', content={'a': 'b'})
23 msg['buffers'] = [ os.urandom(2) for i in range(3) ]
24 bmsg = serialize_binary_message(msg)
25 msg2 = deserialize_binary_message(bmsg)
26 nt.assert_equal(msg2, msg)
@@ -1,202 +1,257 @@
1 1 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 import struct
7 8
8 9 try:
9 10 from urllib.parse import urlparse # Py 3
10 11 except ImportError:
11 12 from urlparse import urlparse # Py 2
12 13
13 14 try:
14 15 from http.cookies import SimpleCookie # Py 3
15 16 except ImportError:
16 17 from Cookie import SimpleCookie # Py 2
17 18 import logging
18 19
19 20 import tornado
20 21 from tornado import ioloop
21 22 from tornado import web
22 23 from tornado import websocket
23 24
24 25 from IPython.kernel.zmq.session import Session
25 from IPython.utils.jsonutil import date_default
26 from IPython.utils.jsonutil import date_default, extract_dates
26 27 from IPython.utils.py3compat import PY3, cast_unicode
27 28
28 29 from .handlers import IPythonHandler
29 30
30 31
32 def serialize_binary_message(msg):
33 """serialize a message as a binary blob
34
35 Header:
36
37 4 bytes: number of msg parts (nbufs) as 32b int
38 4 * nbufs bytes: offset for each buffer as a 32b int
39
40 Offsets are from the start of the buffer, including the header.
41
42 Returns
43 -------
44
45 The message serialized to bytes.
46
47 """
48 # don't modify msg or buffer list in-place
49 msg = msg.copy()
50 buffers = list(msg.pop('buffers'))
51 bmsg = json.dumps(msg, default=date_default).encode('utf8')
52 buffers.insert(0, bmsg)
53 nbufs = len(buffers)
54 offsets = [4 * (nbufs + 1)]
55 for buf in buffers[:-1]:
56 offsets.append(offsets[-1] + len(buf))
57 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
58 buffers.insert(0, offsets_buf)
59 return b''.join(buffers)
60
61
62 def deserialize_binary_message(bmsg):
63 """deserialize a message from a binary blog
64
65 Header:
66
67 4 bytes: number of msg parts (nbufs) as 32b int
68 4 * nbufs bytes: offset for each buffer as a 32b int
69
70 Offsets are from the start of the buffer, including the header.
71
72 Returns
73 -------
74
75 message dictionary
76 """
77 nbufs = struct.unpack('!I', bmsg[:4])[0]
78 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
79 offsets.append(None)
80 bufs = []
81 for start, stop in zip(offsets[:-1], offsets[1:]):
82 bufs.append(bmsg[start:stop])
83 msg = json.loads(bufs[0].decode('utf8'))
84 msg['header'] = extract_dates(msg['header'])
85 msg['parent_header'] = extract_dates(msg['parent_header'])
86 msg['buffers'] = bufs[1:]
87 return msg
88
89
31 90 class ZMQStreamHandler(websocket.WebSocketHandler):
32 91
33 92 def check_origin(self, origin):
34 93 """Check Origin == Host or Access-Control-Allow-Origin.
35 94
36 95 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
37 96 We call it explicitly in `open` on Tornado < 4.
38 97 """
39 98 if self.allow_origin == '*':
40 99 return True
41 100
42 101 host = self.request.headers.get("Host")
43 102
44 103 # If no header is provided, assume we can't verify origin
45 104 if origin is None:
46 105 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
47 106 return False
48 107 if host is None:
49 108 self.log.warn("Missing Host header, rejecting WebSocket connection.")
50 109 return False
51 110
52 111 origin = origin.lower()
53 112 origin_host = urlparse(origin).netloc
54 113
55 114 # OK if origin matches host
56 115 if origin_host == host:
57 116 return True
58 117
59 118 # Check CORS headers
60 119 if self.allow_origin:
61 120 allow = self.allow_origin == origin
62 121 elif self.allow_origin_pat:
63 122 allow = bool(self.allow_origin_pat.match(origin))
64 123 else:
65 124 # No CORS headers: deny the request
66 125 allow = False
67 126 if not allow:
68 127 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
69 128 origin, host,
70 129 )
71 130 return allow
72 131
73 132 def clear_cookie(self, *args, **kwargs):
74 133 """meaningless for websockets"""
75 134 pass
76 135
77 136 def _reserialize_reply(self, msg_list):
78 137 """Reserialize a reply message using JSON.
79 138
80 This takes the msg list from the ZMQ socket, unserializes it using
139 This takes the msg list from the ZMQ socket, deserializes it using
81 140 self.session and then serializes the result using JSON. This method
82 141 should be used by self._on_zmq_reply to build messages that can
83 142 be sent back to the browser.
84 143 """
85 144 idents, msg_list = self.session.feed_identities(msg_list)
86 msg = self.session.unserialize(msg_list)
87 try:
88 msg['header'].pop('date')
89 except KeyError:
90 pass
91 try:
92 msg['parent_header'].pop('date')
93 except KeyError:
94 pass
95 msg.pop('buffers')
96 return json.dumps(msg, default=date_default)
145 msg = self.session.deserialize(msg_list)
146 if msg['buffers']:
147 buf = serialize_binary_message(msg)
148 return buf
149 else:
150 smsg = json.dumps(msg, default=date_default)
151 return cast_unicode(smsg)
97 152
98 153 def _on_zmq_reply(self, msg_list):
99 154 # Sometimes this gets triggered when the on_close method is scheduled in the
100 155 # eventloop but hasn't been called.
101 156 if self.stream.closed(): return
102 157 try:
103 158 msg = self._reserialize_reply(msg_list)
104 159 except Exception:
105 160 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
106 161 else:
107 self.write_message(msg)
162 self.write_message(msg, binary=isinstance(msg, bytes))
108 163
109 164 def allow_draft76(self):
110 165 """Allow draft 76, until browsers such as Safari update to RFC 6455.
111 166
112 167 This has been disabled by default in tornado in release 2.2.0, and
113 168 support will be removed in later versions.
114 169 """
115 170 return True
116 171
117 172 # ping interval for keeping websockets alive (30 seconds)
118 173 WS_PING_INTERVAL = 30000
119 174
120 175 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
121 176 ping_callback = None
122 177 last_ping = 0
123 178 last_pong = 0
124 179
125 180 @property
126 181 def ping_interval(self):
127 182 """The interval for websocket keep-alive pings.
128 183
129 184 Set ws_ping_interval = 0 to disable pings.
130 185 """
131 186 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
132 187
133 188 @property
134 189 def ping_timeout(self):
135 190 """If no ping is received in this many milliseconds,
136 191 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
137 192 Default is max of 3 pings or 30 seconds.
138 193 """
139 194 return self.settings.get('ws_ping_timeout',
140 195 max(3 * self.ping_interval, WS_PING_INTERVAL)
141 196 )
142 197
143 198 def set_default_headers(self):
144 199 """Undo the set_default_headers in IPythonHandler
145 200
146 201 which doesn't make sense for websockets
147 202 """
148 203 pass
149 204
150 205 def get(self, *args, **kwargs):
151 206 # Check to see that origin matches host directly, including ports
152 207 # Tornado 4 already does CORS checking
153 208 if tornado.version_info[0] < 4:
154 209 if not self.check_origin(self.get_origin()):
155 210 raise web.HTTPError(403)
156 211
157 212 # authenticate the request before opening the websocket
158 213 if self.get_current_user() is None:
159 214 self.log.warn("Couldn't authenticate WebSocket connection")
160 215 raise web.HTTPError(403)
161 216
162 217 if self.get_argument('session_id', False):
163 218 self.session.session = cast_unicode(self.get_argument('session_id'))
164 219 else:
165 220 self.log.warn("No session ID specified")
166 221
167 222 return super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
168 223
169 224 def initialize(self):
170 225 self.session = Session(config=self.config)
171 226
172 227 def open(self, kernel_id):
173 228 self.kernel_id = cast_unicode(kernel_id, 'ascii')
174 229
175 230 # start the pinging
176 231 if self.ping_interval > 0:
177 232 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
178 233 self.last_pong = self.last_ping
179 234 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
180 235 self.ping_callback.start()
181 236
182 237 def send_ping(self):
183 238 """send a ping to keep the websocket alive"""
184 239 if self.stream.closed() and self.ping_callback is not None:
185 240 self.ping_callback.stop()
186 241 return
187 242
188 243 # check for timeout on pong. Make sure that we really have sent a recent ping in
189 244 # case the machine with both server and client has been suspended since the last ping.
190 245 now = ioloop.IOLoop.instance().time()
191 246 since_last_pong = 1e3 * (now - self.last_pong)
192 247 since_last_ping = 1e3 * (now - self.last_ping)
193 248 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
194 249 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
195 250 self.close()
196 251 return
197 252
198 253 self.ping(b'')
199 254 self.last_ping = now
200 255
201 256 def on_pong(self, data):
202 257 self.last_pong = ioloop.IOLoop.instance().time()
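To make the framing concrete, here is a worked example (illustrative only, in the same JavaScript as serialize.js above) for a frame whose JSON part is 6 bytes and which carries one 2-byte binary buffer. nbufs is 2 because the JSON part counts as the first buffer, so the header occupies 4 * (2 + 1) = 12 bytes and the two offsets are 12 and 18:

    // Header layout: big-endian (network byte order) 32-bit unsigned ints.
    var nbufs = 2;                        // JSON part + one binary buffer
    var offsets = [4 * (nbufs + 1)];      // 12: first byte after the header
    offsets.push(offsets[0] + 6);         // 18: first byte after the 6-byte JSON part
    var header = new DataView(new ArrayBuffer(4 * (nbufs + 1)));
    header.setUint32(0, nbufs);           // bytes 0-3:  00 00 00 02
    header.setUint32(4, offsets[0]);      // bytes 4-7:  00 00 00 0c
    header.setUint32(8, offsets[1]);      // bytes 8-11: 00 00 00 12
    // Full frame: 12-byte header + JSON at offset 12 + buffer at offset 18,
    // 20 bytes total. The last buffer runs to the end of the frame, which is
    // why the deserializers can use a trailing sentinel instead of a length.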
@@ -1,230 +1,233 @@
1 1 """Tornado handlers for kernels."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import json
7 7 import logging
8 8 from tornado import web
9 9
10 10 from IPython.utils.jsonutil import date_default
11 11 from IPython.utils.py3compat import string_types
12 12 from IPython.html.utils import url_path_join, url_escape
13 13
14 14 from ...base.handlers import IPythonHandler, json_errors
15 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler
15 from ...base.zmqhandlers import AuthenticatedZMQStreamHandler, deserialize_binary_message
16 16
17 17 from IPython.core.release import kernel_protocol_version
18 18
19 19 class MainKernelHandler(IPythonHandler):
20 20
21 21 @web.authenticated
22 22 @json_errors
23 23 def get(self):
24 24 km = self.kernel_manager
25 25 self.finish(json.dumps(km.list_kernels()))
26 26
27 27 @web.authenticated
28 28 @json_errors
29 29 def post(self):
30 30 km = self.kernel_manager
31 31 model = self.get_json_body()
32 32 if model is None:
33 33 model = {
34 34 'name': km.default_kernel_name
35 35 }
36 36 else:
37 37 model.setdefault('name', km.default_kernel_name)
38 38
39 39 kernel_id = km.start_kernel(kernel_name=model['name'])
40 40 model = km.kernel_model(kernel_id)
41 41 location = url_path_join(self.base_url, 'api', 'kernels', kernel_id)
42 42 self.set_header('Location', url_escape(location))
43 43 self.set_status(201)
44 44 self.finish(json.dumps(model))
45 45
46 46
47 47 class KernelHandler(IPythonHandler):
48 48
49 49 SUPPORTED_METHODS = ('DELETE', 'GET')
50 50
51 51 @web.authenticated
52 52 @json_errors
53 53 def get(self, kernel_id):
54 54 km = self.kernel_manager
55 55 km._check_kernel_id(kernel_id)
56 56 model = km.kernel_model(kernel_id)
57 57 self.finish(json.dumps(model))
58 58
59 59 @web.authenticated
60 60 @json_errors
61 61 def delete(self, kernel_id):
62 62 km = self.kernel_manager
63 63 km.shutdown_kernel(kernel_id)
64 64 self.set_status(204)
65 65 self.finish()
66 66
67 67
68 68 class KernelActionHandler(IPythonHandler):
69 69
70 70 @web.authenticated
71 71 @json_errors
72 72 def post(self, kernel_id, action):
73 73 km = self.kernel_manager
74 74 if action == 'interrupt':
75 75 km.interrupt_kernel(kernel_id)
76 76 self.set_status(204)
77 77 if action == 'restart':
78 78 km.restart_kernel(kernel_id)
79 79 model = km.kernel_model(kernel_id)
80 80 self.set_header('Location', '{0}api/kernels/{1}'.format(self.base_url, kernel_id))
81 81 self.write(json.dumps(model))
82 82 self.finish()
83 83
84 84
85 85 class ZMQChannelHandler(AuthenticatedZMQStreamHandler):
86 86
87 87 def __repr__(self):
88 88 return "%s(%s)" % (self.__class__.__name__, getattr(self, 'kernel_id', 'uninitialized'))
89 89
90 90 def create_stream(self):
91 91 km = self.kernel_manager
92 92 meth = getattr(km, 'connect_%s' % self.channel)
93 93 self.zmq_stream = meth(self.kernel_id, identity=self.session.bsession)
94 94 # Create a kernel_info channel to query the kernel protocol version.
95 95 # This channel will be closed after the kernel_info reply is received.
96 96 self.kernel_info_channel = None
97 97 self.kernel_info_channel = km.connect_shell(self.kernel_id)
98 98 self.kernel_info_channel.on_recv(self._handle_kernel_info_reply)
99 99 self._request_kernel_info()
100 100
101 101 def _request_kernel_info(self):
102 102 """send a request for kernel_info"""
103 103 self.log.debug("requesting kernel info")
104 104 self.session.send(self.kernel_info_channel, "kernel_info_request")
105 105
106 106 def _handle_kernel_info_reply(self, msg):
107 107 """process the kernel_info_reply
108 108
109 109 enabling msg spec adaptation, if necessary
110 110 """
111 111 idents,msg = self.session.feed_identities(msg)
112 112 try:
113 msg = self.session.unserialize(msg)
113 msg = self.session.deserialize(msg)
114 114 except:
115 115 self.log.error("Bad kernel_info reply", exc_info=True)
116 116 self._request_kernel_info()
117 117 return
118 118 else:
119 119 if msg['msg_type'] != 'kernel_info_reply' or 'protocol_version' not in msg['content']:
120 120 self.log.error("Kernel info request failed, assuming current %s", msg['content'])
121 121 else:
122 122 protocol_version = msg['content']['protocol_version']
123 123 if protocol_version != kernel_protocol_version:
124 124 self.session.adapt_version = int(protocol_version.split('.')[0])
125 125 self.log.info("adapting kernel to %s" % protocol_version)
126 126 self.kernel_info_channel.close()
127 127 self.kernel_info_channel = None
128 128
129 129 def initialize(self):
130 130 super(ZMQChannelHandler, self).initialize()
131 131 self.zmq_stream = None
132 132
133 133 def open(self, kernel_id):
134 134 super(ZMQChannelHandler, self).open(kernel_id)
135 135 try:
136 136 self.create_stream()
137 137 except web.HTTPError:
138 138 # WebSockets don't respond to traditional error codes so we
139 139 # close the connection.
140 140 if not self.stream.closed():
141 141 self.stream.close()
142 142 self.close()
143 143 else:
144 144 self.zmq_stream.on_recv(self._on_zmq_reply)
145 145
146 146 def on_message(self, msg):
147 147 if self.zmq_stream is None:
148 148 return
149 149 elif self.zmq_stream.closed():
150 150 self.log.info("%s closed, closing websocket.", self)
151 151 self.close()
152 152 return
153 if isinstance(msg, bytes):
154 msg = deserialize_binary_message(msg)
155 else:
153 156 msg = json.loads(msg)
154 157 self.session.send(self.zmq_stream, msg)
155 158
156 159 def on_close(self):
157 160 # This method can be called twice, once by self.kernel_died and once
158 161 # from the WebSocket close event. If the WebSocket connection is
159 162 # closed before the ZMQ streams are setup, they could be None.
160 163 if self.zmq_stream is not None and not self.zmq_stream.closed():
161 164 self.zmq_stream.on_recv(None)
162 165 # close the socket directly, don't wait for the stream
163 166 socket = self.zmq_stream.socket
164 167 self.zmq_stream.close()
165 168 socket.close()
166 169
167 170
168 171 class IOPubHandler(ZMQChannelHandler):
169 172 channel = 'iopub'
170 173
171 174 def create_stream(self):
172 175 super(IOPubHandler, self).create_stream()
173 176 km = self.kernel_manager
174 177 km.add_restart_callback(self.kernel_id, self.on_kernel_restarted)
175 178 km.add_restart_callback(self.kernel_id, self.on_restart_failed, 'dead')
176 179
177 180 def on_close(self):
178 181 km = self.kernel_manager
179 182 if self.kernel_id in km:
180 183 km.remove_restart_callback(
181 184 self.kernel_id, self.on_kernel_restarted,
182 185 )
183 186 km.remove_restart_callback(
184 187 self.kernel_id, self.on_restart_failed, 'dead',
185 188 )
186 189 super(IOPubHandler, self).on_close()
187 190
188 191 def _send_status_message(self, status):
189 192 msg = self.session.msg("status",
190 193 {'execution_state': status}
191 194 )
192 195 self.write_message(json.dumps(msg, default=date_default))
193 196
194 197 def on_kernel_restarted(self):
195 198 logging.warn("kernel %s restarted", self.kernel_id)
196 199 self._send_status_message('restarting')
197 200
198 201 def on_restart_failed(self):
199 202 logging.error("kernel %s restart failed!", self.kernel_id)
200 203 self._send_status_message('dead')
201 204
202 205 def on_message(self, msg):
203 206 """IOPub messages make no sense"""
204 207 pass
205 208
206 209
207 210 class ShellHandler(ZMQChannelHandler):
208 211 channel = 'shell'
209 212
210 213
211 214 class StdinHandler(ZMQChannelHandler):
212 215 channel = 'stdin'
213 216
214 217
215 218 #-----------------------------------------------------------------------------
216 219 # URL to handler mappings
217 220 #-----------------------------------------------------------------------------
218 221
219 222
220 223 _kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
221 224 _kernel_action_regex = r"(?P<action>restart|interrupt)"
222 225
223 226 default_handlers = [
224 227 (r"/api/kernels", MainKernelHandler),
225 228 (r"/api/kernels/%s" % _kernel_id_regex, KernelHandler),
226 229 (r"/api/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
227 230 (r"/api/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
228 231 (r"/api/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
229 232 (r"/api/kernels/%s/stdin" % _kernel_id_regex, StdinHandler)
230 233 ]
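With this change a frame's websocket type selects the serialization: text frames remain plain JSON, binary frames carry the offset header. The browser performs the mirror-image dispatch; a sketch, assuming `serialize` is the module added in this PR and `handle` is some message callback (both names hypothetical here):

    // Client-side mirror of ZMQChannelHandler.on_message (sketch).
    ws.onmessage = function (evt) {
        if (typeof evt.data === 'string') {
            handle(JSON.parse(evt.data));            // text frame: plain JSON
        } else {
            serialize.deserialize(evt.data, handle); // binary frame: Blob or ArrayBuffer
        }
    };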
NO CONTENT: modified file
@@ -1,1 +1,1 @@
1 Subproject commit 56b35d85bb0ea150458282f4064292a5c211025a
1 Subproject commit 1968f4f78d7e8cd227d0b3f4cc3183591969b52a
@@ -1,190 +1,190 @@
1 1 // Copyright (c) IPython Development Team.
2 2 // Distributed under the terms of the Modified BSD License.
3 3
4 4 define([
5 5 'base/js/namespace',
6 6 'jquery',
7 7 'base/js/utils',
8 8 ], function(IPython, $, utils) {
9 9 "use strict";
10 10
11 11 //-----------------------------------------------------------------------
12 12 // CommManager class
13 13 //-----------------------------------------------------------------------
14 14
15 15 var CommManager = function (kernel) {
16 16 this.comms = {};
17 17 this.targets = {};
18 18 if (kernel !== undefined) {
19 19 this.init_kernel(kernel);
20 20 }
21 21 };
22 22
23 23 CommManager.prototype.init_kernel = function (kernel) {
24 24 // connect the kernel, and register message handlers
25 25 this.kernel = kernel;
26 26 var msg_types = ['comm_open', 'comm_msg', 'comm_close'];
27 27 for (var i = 0; i < msg_types.length; i++) {
28 28 var msg_type = msg_types[i];
29 29 kernel.register_iopub_handler(msg_type, $.proxy(this[msg_type], this));
30 30 }
31 31 };
32 32
33 33 CommManager.prototype.new_comm = function (target_name, data, callbacks, metadata) {
34 34 // Create a new Comm, register it, and open its Kernel-side counterpart
35 35 // Mimics the auto-registration in `Comm.__init__` in the IPython Comm class
36 36 var comm = new Comm(target_name);
37 37 this.register_comm(comm);
38 38 comm.open(data, callbacks, metadata);
39 39 return comm;
40 40 };
41 41
42 42 CommManager.prototype.register_target = function (target_name, f) {
43 43 // Register a target function for a given target name
44 44 this.targets[target_name] = f;
45 45 };
46 46
47 47 CommManager.prototype.unregister_target = function (target_name, f) {
48 48 // Unregister a target function for a given target name
49 49 delete this.targets[target_name];
50 50 };
51 51
52 52 CommManager.prototype.register_comm = function (comm) {
53 53 // Register a comm in the mapping
54 54 this.comms[comm.comm_id] = comm;
55 55 comm.kernel = this.kernel;
56 56 return comm.comm_id;
57 57 };
58 58
59 59 CommManager.prototype.unregister_comm = function (comm) {
60 60 // Remove a comm from the mapping
61 61 delete this.comms[comm.comm_id];
62 62 };
63 63
64 64 // comm message handlers
65 65
66 66 CommManager.prototype.comm_open = function (msg) {
67 67 var content = msg.content;
68 68 var f = this.targets[content.target_name];
69 69 if (f === undefined) {
70 70 console.log("No such target registered: ", content.target_name);
71 71 console.log("Available targets are: ", this.targets);
72 72 return;
73 73 }
74 74 var comm = new Comm(content.target_name, content.comm_id);
75 75 this.register_comm(comm);
76 76 try {
77 77 f(comm, msg);
78 78 } catch (e) {
79 79 console.log("Exception opening new comm:", e, e.stack, msg);
80 80 comm.close();
81 81 this.unregister_comm(comm);
82 82 }
83 83 };
84 84
85 85 CommManager.prototype.comm_close = function (msg) {
86 86 var content = msg.content;
87 87 var comm = this.comms[content.comm_id];
88 88 if (comm === undefined) {
89 89 return;
90 90 }
91 91 this.unregister_comm(comm);
92 92 try {
93 93 comm.handle_close(msg);
94 94 } catch (e) {
95 95 console.log("Exception closing comm: ", e, e.stack, msg);
96 96 }
97 97 };
98 98
99 99 CommManager.prototype.comm_msg = function (msg) {
100 100 var content = msg.content;
101 101 var comm = this.comms[content.comm_id];
102 102 if (comm === undefined) {
103 103 return;
104 104 }
105 105 try {
106 106 comm.handle_msg(msg);
107 107 } catch (e) {
108 108 console.log("Exception handling comm msg: ", e, e.stack, msg);
109 109 }
110 110 };
111 111
112 112 //-----------------------------------------------------------------------
113 113 // Comm base class
114 114 //-----------------------------------------------------------------------
115 115
116 116 var Comm = function (target_name, comm_id) {
117 117 this.target_name = target_name;
118 118 this.comm_id = comm_id || utils.uuid();
119 119 this._msg_callback = this._close_callback = null;
120 120 };
121 121
122 122 // methods for sending messages
123 123 Comm.prototype.open = function (data, callbacks, metadata) {
124 124 var content = {
125 125 comm_id : this.comm_id,
126 126 target_name : this.target_name,
127 127 data : data || {},
128 128 };
129 129 return this.kernel.send_shell_message("comm_open", content, callbacks, metadata);
130 130 };
131 131
132 Comm.prototype.send = function (data, callbacks, metadata) {
132 Comm.prototype.send = function (data, callbacks, metadata, buffers) {
133 133 var content = {
134 134 comm_id : this.comm_id,
135 135 data : data || {},
136 136 };
137 return this.kernel.send_shell_message("comm_msg", content, callbacks, metadata);
137 return this.kernel.send_shell_message("comm_msg", content, callbacks, metadata, buffers);
138 138 };
139 139
140 140 Comm.prototype.close = function (data, callbacks, metadata) {
141 141 var content = {
142 142 comm_id : this.comm_id,
143 143 data : data || {},
144 144 };
145 145 return this.kernel.send_shell_message("comm_close", content, callbacks, metadata);
146 146 };
147 147
148 148 // methods for registering callbacks for incoming messages
149 149 Comm.prototype._register_callback = function (key, callback) {
150 150 this['_' + key + '_callback'] = callback;
151 151 };
152 152
153 153 Comm.prototype.on_msg = function (callback) {
154 154 this._register_callback('msg', callback);
155 155 };
156 156
157 157 Comm.prototype.on_close = function (callback) {
158 158 this._register_callback('close', callback);
159 159 };
160 160
161 161 // methods for handling incoming messages
162 162
163 163 Comm.prototype._maybe_callback = function (key, msg) {
164 164 var callback = this['_' + key + '_callback'];
165 165 if (callback) {
166 166 try {
167 167 callback(msg);
168 168 } catch (e) {
169 169 console.log("Exception in Comm callback", e, e.stack, msg);
170 170 }
171 171 }
172 172 };
173 173
174 174 Comm.prototype.handle_msg = function (msg) {
175 175 this._maybe_callback('msg', msg);
176 176 };
177 177
178 178 Comm.prototype.handle_close = function (msg) {
179 179 this._maybe_callback('close', msg);
180 180 };
181 181
182 182 // For backwards compatibility.
183 183 IPython.CommManager = CommManager;
184 184 IPython.Comm = Comm;
185 185
186 186 return {
187 187 'CommManager': CommManager,
188 188 'Comm': Comm
189 189 };
190 190 });
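With `buffers` threaded through `Comm.send` (and `Kernel.send_shell_message` below), frontend code can attach raw binary payloads to a comm message instead of base64-encoding them into the JSON body. A hedged usage sketch on an already-open comm; the data and payload here are invented:

    // Hypothetical: send 3 raw bytes alongside a JSON payload.
    var bytes = new Uint8Array([0, 1, 2]);
    comm.send({shape: [3]},                  // JSON-serializable data
              undefined,                     // callbacks: none
              {},                            // metadata
              [new DataView(bytes.buffer)]); // out-of-band binary buffers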
@@ -1,1014 +1,1026 @@
1 1 // Copyright (c) IPython Development Team.
2 2 // Distributed under the terms of the Modified BSD License.
3 3
4 4 define([
5 5 'base/js/namespace',
6 6 'jquery',
7 7 'base/js/utils',
8 'services/kernels/js/comm',
9 'widgets/js/init',
10 ], function(IPython, $, utils, comm, widgetmanager) {
8 './comm',
9 './serialize',
10 'widgets/js/init'
11 ], function(IPython, $, utils, comm, serialize, widgetmanager) {
11 12 "use strict";
12 13
13 14 /**
14 15 * A Kernel class to communicate with the Python kernel. This
15 16 * should generally not be constructed directly, but be created
16 17 * by the `Session` object. Once created, this object should be
17 18 * used to communicate with the kernel.
18 19 *
19 20 * @class Kernel
20 21 * @param {string} kernel_service_url - the URL to access the kernel REST api
21 22 * @param {string} ws_url - the websockets URL
22 23 * @param {Notebook} notebook - notebook object
23 24 * @param {string} name - the kernel type (e.g. python3)
24 25 */
25 26 var Kernel = function (kernel_service_url, ws_url, notebook, name) {
26 27 this.events = notebook.events;
27 28
28 29 this.id = null;
29 30 this.name = name;
30 31
31 32 this.channels = {
32 33 'shell': null,
33 34 'iopub': null,
34 35 'stdin': null
35 36 };
36 37
37 38 this.kernel_service_url = kernel_service_url;
38 39 this.kernel_url = null;
39 40 this.ws_url = ws_url || IPython.utils.get_body_data("wsUrl");
40 41 if (!this.ws_url) {
41 42 // trailing 's' in https will become wss for secure web sockets
42 43 this.ws_url = location.protocol.replace('http', 'ws') + "//" + location.host;
43 44 }
44 45
45 46 this.username = "username";
46 47 this.session_id = utils.uuid();
47 48 this._msg_callbacks = {};
48 49
49 50 if (typeof(WebSocket) !== 'undefined') {
50 51 this.WebSocket = WebSocket;
51 52 } else if (typeof(MozWebSocket) !== 'undefined') {
52 53 this.WebSocket = MozWebSocket;
53 54 } else {
54 55 alert('Your browser does not have WebSocket support, please try Chrome, Safari or Firefox ≥ 6. Firefox 4 and 5 are also supported but you have to enable WebSockets in about:config.');
55 56 }
56 57
57 58 this.bind_events();
58 59 this.init_iopub_handlers();
59 60 this.comm_manager = new comm.CommManager(this);
60 61 this.widget_manager = new widgetmanager.WidgetManager(this.comm_manager, notebook);
61 62
62 63 this.last_msg_id = null;
63 64 this.last_msg_callbacks = {};
64 65
65 66 this._autorestart_attempt = 0;
66 67 this._reconnect_attempt = 0;
67 68 };
68 69
69 70 /**
70 71 * @function _get_msg
71 72 */
72 Kernel.prototype._get_msg = function (msg_type, content, metadata) {
73 Kernel.prototype._get_msg = function (msg_type, content, metadata, buffers) {
73 74 var msg = {
74 75 header : {
75 76 msg_id : utils.uuid(),
76 77 username : this.username,
77 78 session : this.session_id,
78 79 msg_type : msg_type,
79 80 version : "5.0"
80 81 },
81 82 metadata : metadata || {},
82 83 content : content,
84 buffers : buffers || [],
83 85 parent_header : {}
84 86 };
85 87 return msg;
86 88 };
87 89
88 90 /**
89 91 * @function bind_events
90 92 */
91 93 Kernel.prototype.bind_events = function () {
92 94 var that = this;
93 95 this.events.on('send_input_reply.Kernel', function(evt, data) {
94 96 that.send_input_reply(data);
95 97 });
96 98
97 99 var record_status = function (evt, info) {
98 100 console.log('Kernel: ' + evt.type + ' (' + info.kernel.id + ')');
99 101 };
100 102
101 103 this.events.on('kernel_created.Kernel', record_status);
102 104 this.events.on('kernel_reconnecting.Kernel', record_status);
103 105 this.events.on('kernel_connected.Kernel', record_status);
104 106 this.events.on('kernel_starting.Kernel', record_status);
105 107 this.events.on('kernel_restarting.Kernel', record_status);
106 108 this.events.on('kernel_autorestarting.Kernel', record_status);
107 109 this.events.on('kernel_interrupting.Kernel', record_status);
108 110 this.events.on('kernel_disconnected.Kernel', record_status);
109 111 // these are commented out because they are triggered a lot, but can
110 112 // be uncommented for debugging purposes
111 113 //this.events.on('kernel_idle.Kernel', record_status);
112 114 //this.events.on('kernel_busy.Kernel', record_status);
113 115 this.events.on('kernel_ready.Kernel', record_status);
114 116 this.events.on('kernel_killed.Kernel', record_status);
115 117 this.events.on('kernel_dead.Kernel', record_status);
116 118
117 119 this.events.on('kernel_ready.Kernel', function () {
118 120 that._autorestart_attempt = 0;
119 121 });
120 122 this.events.on('kernel_connected.Kernel', function () {
121 123 that._reconnect_attempt = 0;
122 124 });
123 125 };
124 126
125 127 /**
126 128 * Initialize the iopub handlers.
127 129 *
128 130 * @function init_iopub_handlers
129 131 */
130 132 Kernel.prototype.init_iopub_handlers = function () {
131 133 var output_msg_types = ['stream', 'display_data', 'execute_result', 'error'];
132 134 this._iopub_handlers = {};
133 135 this.register_iopub_handler('status', $.proxy(this._handle_status_message, this));
134 136 this.register_iopub_handler('clear_output', $.proxy(this._handle_clear_output, this));
135 137
136 138 for (var i=0; i < output_msg_types.length; i++) {
137 139 this.register_iopub_handler(output_msg_types[i], $.proxy(this._handle_output_message, this));
138 140 }
139 141 };
140 142
141 143 /**
142 144 * GET /api/kernels
143 145 *
144 146 * Get the list of running kernels.
145 147 *
146 148 * @function list
147 149 * @param {function} [success] - function executed on ajax success
148 150 * @param {function} [error] - function executed on ajax error
149 151 */
150 152 Kernel.prototype.list = function (success, error) {
151 153 $.ajax(this.kernel_service_url, {
152 154 processData: false,
153 155 cache: false,
154 156 type: "GET",
155 157 dataType: "json",
156 158 success: success,
157 159 error: this._on_error(error)
158 160 });
159 161 };
160 162
161 163 /**
162 164 * POST /api/kernels
163 165 *
164 166 * Start a new kernel.
165 167 *
166 168 * In general this shouldn't be used -- the kernel should be
167 169 * started through the session API. If you use this function and
168 170 * are also using the session API then your session and kernel
169 171 * WILL be out of sync!
170 172 *
171 173 * @function start
172 174 * @param {Object} [params] - parameters to include in the query string
173 175 * @param {function} [success] - function executed on ajax success
174 176 * @param {function} [error] - functon executed on ajax error
175 177 */
176 178 Kernel.prototype.start = function (params, success, error) {
177 179 var url = this.kernel_service_url;
178 180 var qs = $.param(params || {}); // query string for sage math stuff
179 181 if (qs !== "") {
180 182 url = url + "?" + qs;
181 183 }
182 184
183 185 var that = this;
184 186 var on_success = function (data, status, xhr) {
185 187 that.events.trigger('kernel_created.Kernel', {kernel: that});
186 188 that._kernel_created(data);
187 189 if (success) {
188 190 success(data, status, xhr);
189 191 }
190 192 };
191 193
192 194 $.ajax(url, {
193 195 processData: false,
194 196 cache: false,
195 197 type: "POST",
196 198 data: JSON.stringify({name: this.name}),
197 199 dataType: "json",
198 200 success: this._on_success(on_success),
199 201 error: this._on_error(error)
200 202 });
201 203
202 204 return url;
203 205 };
204 206
205 207 /**
206 208 * GET /api/kernels/[:kernel_id]
207 209 *
208 210 * Get information about the kernel.
209 211 *
210 212 * @function get_info
211 213 * @param {function} [success] - function executed on ajax success
212 214 * @param {function} [error] - functon executed on ajax error
213 215 */
214 216 Kernel.prototype.get_info = function (success, error) {
215 217 $.ajax(this.kernel_url, {
216 218 processData: false,
217 219 cache: false,
218 220 type: "GET",
219 221 dataType: "json",
220 222 success: this._on_success(success),
221 223 error: this._on_error(error)
222 224 });
223 225 };
224 226
225 227 /**
226 228 * DELETE /api/kernels/[:kernel_id]
227 229 *
228 230 * Shutdown the kernel.
229 231 *
230 232 * If you are also using sessions, then this function should NOT be
231 233 * used. Instead, use Session.delete. Otherwise, the session and
232 234 * kernel WILL be out of sync.
233 235 *
234 236 * @function kill
235 237 * @param {function} [success] - function executed on ajax success
236 238 * @param {function} [error] - functon executed on ajax error
237 239 */
238 240 Kernel.prototype.kill = function (success, error) {
239 241 this.events.trigger('kernel_killed.Kernel', {kernel: this});
240 242 this._kernel_dead();
241 243 $.ajax(this.kernel_url, {
242 244 processData: false,
243 245 cache: false,
244 246 type: "DELETE",
245 247 dataType: "json",
246 248 success: this._on_success(success),
247 249 error: this._on_error(error)
248 250 });
249 251 };
250 252
251 253 /**
252 254 * POST /api/kernels/[:kernel_id]/interrupt
253 255 *
254 256 * Interrupt the kernel.
255 257 *
256 258 * @function interrupt
257 259 * @param {function} [success] - function executed on ajax success
258 260 * @param {function} [error] - functon executed on ajax error
259 261 */
260 262 Kernel.prototype.interrupt = function (success, error) {
261 263 this.events.trigger('kernel_interrupting.Kernel', {kernel: this});
262 264
263 265 var that = this;
264 266 var on_success = function (data, status, xhr) {
265 267 // get kernel info so we know what state the kernel is in
266 268 that.kernel_info();
267 269 if (success) {
268 270 success(data, status, xhr);
269 271 }
270 272 };
271 273
272 274 var url = utils.url_join_encode(this.kernel_url, 'interrupt');
273 275 $.ajax(url, {
274 276 processData: false,
275 277 cache: false,
276 278 type: "POST",
277 279 dataType: "json",
278 280 success: this._on_success(on_success),
279 281 error: this._on_error(error)
280 282 });
281 283 };
282 284
283 285 /**
284 286 * POST /api/kernels/[:kernel_id]/restart
285 287 *
286 288 * Restart the kernel.
287 289 *
288 290 * @function restart
289 291 * @param {function} [success] - function executed on ajax success
290 292 * @param {function} [error] - function executed on ajax error
291 293 */
292 294 Kernel.prototype.restart = function (success, error) {
293 295 this.events.trigger('kernel_restarting.Kernel', {kernel: this});
294 296 this.stop_channels();
295 297
296 298 var that = this;
297 299 var on_success = function (data, status, xhr) {
298 300 that.events.trigger('kernel_created.Kernel', {kernel: that});
299 301 that._kernel_created(data);
300 302 if (success) {
301 303 success(data, status, xhr);
302 304 }
303 305 };
304 306
305 307 var on_error = function (xhr, status, err) {
306 308 that.events.trigger('kernel_dead.Kernel', {kernel: that});
307 309 that._kernel_dead();
308 310 if (error) {
309 311 error(xhr, status, err);
310 312 }
311 313 };
312 314
313 315 var url = utils.url_join_encode(this.kernel_url, 'restart');
314 316 $.ajax(url, {
315 317 processData: false,
316 318 cache: false,
317 319 type: "POST",
318 320 dataType: "json",
319 321 success: this._on_success(on_success),
320 322 error: this._on_error(on_error)
321 323 });
322 324 };
323 325
324 326 /**
325 327 * Reconnect to a disconnected kernel. This is not actually a
326 328 * standard HTTP request, but a useful function nonetheless for
327 329 * reconnecting to the kernel if the connection is somehow lost.
328 330 *
329 331 * @function reconnect
330 332 */
331 333 Kernel.prototype.reconnect = function () {
332 334 this.events.trigger('kernel_reconnecting.Kernel', {kernel: this});
333 335 setTimeout($.proxy(this.start_channels, this), 3000);
334 336 };
335 337
336 338 /**
337 339 * Handle a successful AJAX request by updating the kernel id and
338 340 * name from the response, and then optionally calling a provided
339 341 * callback.
340 342 *
341 343 * @function _on_success
342 344 * @param {function} success - callback
343 345 */
344 346 Kernel.prototype._on_success = function (success) {
345 347 var that = this;
346 348 return function (data, status, xhr) {
347 349 if (data) {
348 350 that.id = data.id;
349 351 that.name = data.name;
350 352 }
351 353 that.kernel_url = utils.url_join_encode(that.kernel_service_url, that.id);
352 354 if (success) {
353 355 success(data, status, xhr);
354 356 }
355 357 };
356 358 };
357 359
358 360 /**
359 361 * Handle a failed AJAX request by logging the error message, and
360 362 * then optionally calling a provided callback.
361 363 *
362 364 * @function _on_error
363 365 * @param {function} error - callback
364 366 */
365 367 Kernel.prototype._on_error = function (error) {
366 368 return function (xhr, status, err) {
367 369 utils.log_ajax_error(xhr, status, err);
368 370 if (error) {
369 371 error(xhr, status, err);
370 372 }
371 373 };
372 374 };
373 375
374 376 /**
375 377 * Perform necessary tasks once the kernel has been started,
376 378 * including actually connecting to the kernel.
377 379 *
378 380 * @function _kernel_created
379 381 * @param {Object} data - information about the kernel including id
380 382 */
381 383 Kernel.prototype._kernel_created = function (data) {
382 384 this.id = data.id;
383 385 this.kernel_url = utils.url_join_encode(this.kernel_service_url, this.id);
384 386 this.start_channels();
385 387 };
386 388
387 389 /**
388 390 * Perform necessary tasks once the connection to the kernel has
389 391 * been established. This includes requesting information about
390 392 * the kernel.
391 393 *
392 394 * @function _kernel_connected
393 395 */
394 396 Kernel.prototype._kernel_connected = function () {
395 397 this.events.trigger('kernel_connected.Kernel', {kernel: this});
396 398 this.events.trigger('kernel_starting.Kernel', {kernel: this});
397 399 // get kernel info so we know what state the kernel is in
398 400 var that = this;
399 401 this.kernel_info(function () {
400 402 that.events.trigger('kernel_ready.Kernel', {kernel: that});
401 403 });
402 404 };
403 405
404 406 /**
405 407 * Perform necessary tasks after the kernel has died. This includes closing
406 408 * communication channels to the kernel if they are still somehow
407 409 * open.
408 410 *
409 411 * @function _kernel_dead
410 412 */
411 413 Kernel.prototype._kernel_dead = function () {
412 414 this.stop_channels();
413 415 };
414 416
415 417 /**
416 418 * Start the `shell`, `iopub` and `stdin` channels.
417 419 * Will stop and restart them if they already exist.
418 420 *
419 421 * @function start_channels
420 422 */
421 423 Kernel.prototype.start_channels = function () {
422 424 var that = this;
423 425 this.stop_channels();
424 426 var ws_host_url = this.ws_url + this.kernel_url;
425 427
426 428 console.log("Starting WebSockets:", ws_host_url);
427 429
428 430 var channel_url = function(channel) {
429 431 return [
430 432 that.ws_url,
431 433 utils.url_join_encode(that.kernel_url, channel),
432 434 "?session_id=" + that.session_id
433 435 ].join('');
434 436 };
435 437 this.channels.shell = new this.WebSocket(channel_url("shell"));
436 438 this.channels.stdin = new this.WebSocket(channel_url("stdin"));
437 439 this.channels.iopub = new this.WebSocket(channel_url("iopub"));
438 440
439 441 var already_called_onclose = false; // only alert once
440 442 var ws_closed_early = function(evt){
441 443 if (already_called_onclose){
442 444 return;
443 445 }
444 446 already_called_onclose = true;
445 447 if ( ! evt.wasClean ){
446 448 // If the websocket was closed early, that could mean
447 449 // that the kernel is actually dead. Try getting
448 450 // information about the kernel from the API call --
449 451 // if that fails, then assume the kernel is dead,
450 452 // otherwise just follow the typical websocket closed
451 453 // protocol.
452 454 that.get_info(function () {
453 455 that._ws_closed(ws_host_url, false);
454 456 }, function () {
455 457 that.events.trigger('kernel_dead.Kernel', {kernel: that});
456 458 that._kernel_dead();
457 459 });
458 460 }
459 461 };
460 462 var ws_closed_late = function(evt){
461 463 if (already_called_onclose){
462 464 return;
463 465 }
464 466 already_called_onclose = true;
465 467 if ( ! evt.wasClean ){
466 468 that._ws_closed(ws_host_url, false);
467 469 }
468 470 };
469 471 var ws_error = function(evt){
470 472 if (already_called_onclose){
471 473 return;
472 474 }
473 475 already_called_onclose = true;
474 476 that._ws_closed(ws_host_url, true);
475 477 };
476 478
477 479 for (var c in this.channels) {
478 480 this.channels[c].onopen = $.proxy(this._ws_opened, this);
479 481 this.channels[c].onclose = ws_closed_early;
480 482 this.channels[c].onerror = ws_error;
481 483 }
482 484 // switch from early-close to late-close message after 1s
483 485 setTimeout(function() {
484 486 for (var c in that.channels) {
485 487 if (that.channels[c] !== null) {
486 488 that.channels[c].onclose = ws_closed_late;
487 489 }
488 490 }
489 491 }, 1000);
490 492 this.channels.shell.onmessage = $.proxy(this._handle_shell_reply, this);
491 493 this.channels.iopub.onmessage = $.proxy(this._handle_iopub_message, this);
492 494 this.channels.stdin.onmessage = $.proxy(this._handle_input_request, this);
493 495 };
494 496
495 497 /**
496 498 * Handle a websocket entering the open state,
497 499 * signaling that the kernel is connected when all channels are open.
498 500 *
499 501 * @function _ws_opened
500 502 */
501 503 Kernel.prototype._ws_opened = function (evt) {
502 504 if (this.is_connected()) {
503 505 // all events ready, trigger started event.
504 506 this._kernel_connected();
505 507 }
506 508 };
507 509
508 510 /**
509 511 * Handle a websocket entering the closed state. This closes the
510 512 * other communication channels if they are open. If the websocket
511 513 * was not closed due to an error, try to reconnect to the kernel.
512 514 *
513 515 * @function _ws_closed
514 516 * @param {string} ws_url - the websocket url
515 517 * @param {bool} error - whether the connection was closed due to an error
516 518 */
517 519 Kernel.prototype._ws_closed = function(ws_url, error) {
518 520 this.stop_channels();
519 521
520 522 this.events.trigger('kernel_disconnected.Kernel', {kernel: this});
521 523 if (error) {
522 524 console.log('WebSocket connection failed: ', ws_url);
523 525 this._reconnect_attempt = this._reconnect_attempt + 1;
524 526 this.events.trigger('kernel_connection_failed.Kernel', {kernel: this, ws_url: ws_url, attempt: this._reconnect_attempt});
525 527 }
526 528 this.reconnect();
527 529 };
528 530
529 531 /**
530 532 * Close the websocket channels. After successful close, the value
531 533 * in `this.channels[channel_name]` will be null.
532 534 *
533 535 * @function stop_channels
534 536 */
535 537 Kernel.prototype.stop_channels = function () {
536 538 var that = this;
537 539 var close = function (c) {
538 540 return function () {
539 541 if (that.channels[c] && that.channels[c].readyState === WebSocket.CLOSED) {
540 542 that.channels[c] = null;
541 543 }
542 544 };
543 545 };
544 546 for (var c in this.channels) {
545 547 if ( this.channels[c] !== null ) {
546 548 if (this.channels[c].readyState === WebSocket.OPEN) {
547 549 this.channels[c].onclose = close(c);
548 550 this.channels[c].close();
549 551 } else {
550 552 close(c)();
551 553 }
552 554 }
553 555 }
554 556 };
555 557
556 558 /**
557 559 * Check whether there is a connection to the kernel. This
558 560 * function only returns true if all channel objects have been
559 561 * created and have a state of WebSocket.OPEN.
560 562 *
561 563 * @function is_connected
562 564 * @returns {bool} - whether there is a connection
563 565 */
564 566 Kernel.prototype.is_connected = function () {
565 567 for (var c in this.channels) {
566 568 // if any channel is not ready, then we're not connected
567 569 if (this.channels[c] === null) {
568 570 return false;
569 571 }
570 572 if (this.channels[c].readyState !== WebSocket.OPEN) {
571 573 return false;
572 574 }
573 575 }
574 576 return true;
575 577 };
576 578
577 579 /**
578 580 * Check whether the connection to the kernel has been completely
579 581 * severed. This function only returns true if all channel objects
580 582 * are null.
581 583 *
582 584 * @function is_fully_disconnected
583 585 * @returns {bool} - whether the kernel is fully disconnected
584 586 */
585 587 Kernel.prototype.is_fully_disconnected = function () {
586 588 for (var c in this.channels) {
587 589 if (this.channels[c] === null) {
588 590 return true;
589 591 }
590 592 }
591 593 return false;
592 594 };
593 595
594 596 /**
595 597 * Send a message on the Kernel's shell channel
596 598 *
597 599 * @function send_shell_message
598 600 */
599 Kernel.prototype.send_shell_message = function (msg_type, content, callbacks, metadata) {
601 Kernel.prototype.send_shell_message = function (msg_type, content, callbacks, metadata, buffers) {
600 602 if (!this.is_connected()) {
601 603 throw new Error("kernel is not connected");
602 604 }
603 var msg = this._get_msg(msg_type, content, metadata);
604 this.channels.shell.send(JSON.stringify(msg));
605 var msg = this._get_msg(msg_type, content, metadata, buffers);
606 this.channels.shell.send(serialize.serialize(msg));
605 607 this.set_callbacks_for_msg(msg.header.msg_id, callbacks);
606 608 return msg.header.msg_id;
607 609 };
608 610
609 611 /**
610 612 * Get kernel info
611 613 *
612 614 * @function kernel_info
613 615 * @param callback {function}
614 616 *
615 617 * When calling this method, pass a callback function that expects one argument.
616 618 * The callback will be passed the complete `kernel_info_reply` message documented
617 619 * [here](http://ipython.org/ipython-doc/dev/development/messaging.html#kernel-info)
618 620 */
619 621 Kernel.prototype.kernel_info = function (callback) {
620 622 var callbacks;
621 623 if (callback) {
622 624 callbacks = { shell : { reply : callback } };
623 625 }
624 626 return this.send_shell_message("kernel_info_request", {}, callbacks);
625 627 };
626 628
627 629 /**
628 630 * Get info on an object
629 631 *
630 632 * When calling this method, pass a callback function that expects one argument.
631 633 * The callback will be passed the complete `inspect_reply` message documented
632 634 * [here](http://ipython.org/ipython-doc/dev/development/messaging.html#object-information)
633 635 *
634 636 * @function inspect
635 637 * @param code {string}
636 638 * @param cursor_pos {integer}
637 639 * @param callback {function}
638 640 */
639 641 Kernel.prototype.inspect = function (code, cursor_pos, callback) {
640 642 var callbacks;
641 643 if (callback) {
642 644 callbacks = { shell : { reply : callback } };
643 645 }
644 646
645 647 var content = {
646 648 code : code,
647 649 cursor_pos : cursor_pos,
648 650 detail_level : 0
649 651 };
650 652 return this.send_shell_message("inspect_request", content, callbacks);
651 653 };
652 654
653 655 /**
654 656 * Execute given code into kernel, and pass result to callback.
655 657 *
656 658 * @async
657 659 * @function execute
658 660 * @param {string} code
659 661 * @param [callbacks] {Object} With the following keys (all optional)
660 662 * @param callbacks.shell.reply {function}
661 663 * @param callbacks.shell.payload.[payload_name] {function}
662 664 * @param callbacks.iopub.output {function}
663 665 * @param callbacks.iopub.clear_output {function}
664 666 * @param callbacks.input {function}
665 667 * @param {object} [options]
666 668 * @param [options.silent=false] {Boolean}
667 669 * @param [options.user_expressions=empty_dict] {Dict}
668 670 * @param [options.allow_stdin=false] {Boolean} true|false
669 671 *
670 672 * @example
671 673 *
672 674 * The options object should contain the options for the execute
673 675 * call. Its default values are:
674 676 *
675 677 * options = {
676 678 * silent : true,
677 679 * user_expressions : {},
678 680 * allow_stdin : false
679 681 * }
680 682 *
681 683 * When calling this method pass a callbacks structure of the
682 684 * form:
683 685 *
684 686 * callbacks = {
685 687 * shell : {
686 688 * reply : execute_reply_callback,
687 689 * payload : {
688 690 * set_next_input : set_next_input_callback,
689 691 * }
690 692 * },
691 693 * iopub : {
692 694 * output : output_callback,
693 695 * clear_output : clear_output_callback,
694 696 * },
695 697 * input : raw_input_callback
696 698 * }
697 699 *
698 700 * Each callback will be passed the entire message as a single
699 701 * argument. Payload handlers will be passed the corresponding
700 702 * payload and the execute_reply message.
701 703 */
702 704 Kernel.prototype.execute = function (code, callbacks, options) {
703 705 var content = {
704 706 code : code,
705 707 silent : true,
706 708 store_history : false,
707 709 user_expressions : {},
708 710 allow_stdin : false
709 711 };
710 712 callbacks = callbacks || {};
711 713 if (callbacks.input !== undefined) {
712 714 content.allow_stdin = true;
713 715 }
714 716 $.extend(true, content, options);
715 717 this.events.trigger('execution_request.Kernel', {kernel: this, content: content});
716 718 return this.send_shell_message("execute_request", content, callbacks);
717 719 };
718 720
719 721 /**
720 722 * When calling this method, pass a function to be called with the
721 723 * `complete_reply` message as its only argument when it arrives.
722 724 *
723 725 * `complete_reply` is documented
724 726 * [here](http://ipython.org/ipython-doc/dev/development/messaging.html#complete)
725 727 *
726 728 * @function complete
727 729 * @param code {string}
728 730 * @param cursor_pos {integer}
729 731 * @param callback {function}
730 732 */
731 733 Kernel.prototype.complete = function (code, cursor_pos, callback) {
732 734 var callbacks;
733 735 if (callback) {
734 736 callbacks = { shell : { reply : callback } };
735 737 }
736 738 var content = {
737 739 code : code,
738 740 cursor_pos : cursor_pos
739 741 };
740 742 return this.send_shell_message("complete_request", content, callbacks);
741 743 };
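// Usage sketch (assumes a connected `kernel`): request completions at a
// cursor position; the reply's content.matches holds the candidates.
//
//     kernel.complete('impor', 5, function (msg) {
//         console.log(msg.content.matches);
//     });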
742 744
743 745 /**
744 746 * @function send_input_reply
745 747 */
746 748 Kernel.prototype.send_input_reply = function (input) {
747 749 if (!this.is_connected()) {
748 750 throw new Error("kernel is not connected");
749 751 }
750 752 var content = {
751 753 value : input
752 754 };
753 755 this.events.trigger('input_reply.Kernel', {kernel: this, content: content});
754 756 var msg = this._get_msg("input_reply", content);
755 this.channels.stdin.send(JSON.stringify(msg));
757 this.channels.stdin.send(serialize.serialize(msg));
756 758 return msg.header.msg_id;
757 759 };
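// Usage sketch (assumes a connected `kernel`, and that stdin was allowed
// on the originating request): an input callback answering an input_request.
//
//     var callbacks = {
//         input : function (request) {
//             kernel.send_input_reply('some text for raw_input');
//         }
//     };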
758 760
759 761 /**
760 762 * @function register_iopub_handler
761 763 */
762 764 Kernel.prototype.register_iopub_handler = function (msg_type, callback) {
763 765 this._iopub_handlers[msg_type] = callback;
764 766 };
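// Usage sketch: registering a handler for one IOPub message type;
// 'display_data' is a real msg_type, the handler body is illustrative.
//
//     kernel.register_iopub_handler('display_data', function (msg) {
//         console.log(msg.content.data);
//     });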
765 767
766 768 /**
767 769 * Get the iopub handler for a specific message type.
768 770 *
769 771 * @function get_iopub_handler
770 772 */
771 773 Kernel.prototype.get_iopub_handler = function (msg_type) {
772 774 return this._iopub_handlers[msg_type];
773 775 };
774 776
775 777 /**
776 778 * Get callbacks for a specific message.
777 779 *
778 780 * @function get_callbacks_for_msg
779 781 */
780 782 Kernel.prototype.get_callbacks_for_msg = function (msg_id) {
781 783 if (msg_id == this.last_msg_id) {
782 784 return this.last_msg_callbacks;
783 785 } else {
784 786 return this._msg_callbacks[msg_id];
785 787 }
786 788 };
787 789
788 790 /**
789 791 * Clear callbacks for a specific message.
790 792 *
791 793 * @function clear_callbacks_for_msg
792 794 */
793 795 Kernel.prototype.clear_callbacks_for_msg = function (msg_id) {
794 796 if (this._msg_callbacks[msg_id] !== undefined ) {
795 797 delete this._msg_callbacks[msg_id];
796 798 }
797 799 };
798 800
799 801 /**
800 802 * @function _finish_shell
801 803 */
802 804 Kernel.prototype._finish_shell = function (msg_id) {
803 805 var callbacks = this._msg_callbacks[msg_id];
804 806 if (callbacks !== undefined) {
805 807 callbacks.shell_done = true;
806 808 if (callbacks.iopub_done) {
807 809 this.clear_callbacks_for_msg(msg_id);
808 810 }
809 811 }
810 812 };
811 813
812 814 /**
813 815 * @function _finish_iopub
814 816 */
815 817 Kernel.prototype._finish_iopub = function (msg_id) {
816 818 var callbacks = this._msg_callbacks[msg_id];
817 819 if (callbacks !== undefined) {
818 820 callbacks.iopub_done = true;
819 821 if (callbacks.shell_done) {
820 822 this.clear_callbacks_for_msg(msg_id);
821 823 }
822 824 }
823 825 };
824 826
825 827 /**
826 828 * Set callbacks for a particular message.
827 829 * Callbacks should be a struct of the following form:
828 830 * { shell : { reply : reply_callback, payload : {...} },
829 831 * iopub : { output : output_callback, clear_output : clear_output_callback },
830 832 * input : input_callback }
831 833 *
832 834 * @function set_callbacks_for_msg
833 835 */
834 836 Kernel.prototype.set_callbacks_for_msg = function (msg_id, callbacks) {
835 837 this.last_msg_id = msg_id;
836 838 if (callbacks) {
837 839 // shallow-copy mapping, because we will modify it at the top level
838 840 var cbcopy = this._msg_callbacks[msg_id] = this.last_msg_callbacks = {};
839 841 cbcopy.shell = callbacks.shell;
840 842 cbcopy.iopub = callbacks.iopub;
841 843 cbcopy.input = callbacks.input;
842 844 cbcopy.shell_done = (!callbacks.shell);
843 845 cbcopy.iopub_done = (!callbacks.iopub);
844 846 } else {
845 847 this.last_msg_callbacks = {};
846 848 }
847 849 };
848 850
849 851 /**
850 852 * @function _handle_shell_reply
851 853 */
852 854 Kernel.prototype._handle_shell_reply = function (e) {
853 var reply = $.parseJSON(e.data);
855 serialize.deserialize(e.data, $.proxy(this._finish_shell_reply, this));
856 };
857
858 Kernel.prototype._finish_shell_reply = function (reply) {
854 859 this.events.trigger('shell_reply.Kernel', {kernel: this, reply: reply});
855 860 var content = reply.content;
856 861 var metadata = reply.metadata;
857 862 var parent_id = reply.parent_header.msg_id;
858 863 var callbacks = this.get_callbacks_for_msg(parent_id);
859 864 if (!callbacks || !callbacks.shell) {
860 865 return;
861 866 }
862 867 var shell_callbacks = callbacks.shell;
863 868
864 869 // signal that shell callbacks are done
865 870 this._finish_shell(parent_id);
866 871
867 872 if (shell_callbacks.reply !== undefined) {
868 873 shell_callbacks.reply(reply);
869 874 }
870 875 if (content.payload && shell_callbacks.payload) {
871 876 this._handle_payloads(content.payload, shell_callbacks.payload, reply);
872 877 }
873 878 };
874 879
875 880 /**
876 881 * @function _handle_payloads
877 882 */
878 883 Kernel.prototype._handle_payloads = function (payloads, payload_callbacks, msg) {
879 884 var l = payloads.length;
880 885 // Payloads are handled by triggering events because we don't want the Kernel
881 886 // to depend on the Notebook or Pager classes.
882 887 for (var i=0; i<l; i++) {
883 888 var payload = payloads[i];
884 889 var callback = payload_callbacks[payload.source];
885 890 if (callback) {
886 891 callback(payload, msg);
887 892 }
888 893 }
889 894 };
890 895
891 896 /**
892 897 * @function _handle_status_message
893 898 */
894 899 Kernel.prototype._handle_status_message = function (msg) {
895 900 var execution_state = msg.content.execution_state;
896 901 var parent_id = msg.parent_header.msg_id;
897 902
898 903 // dispatch status msg callbacks, if any
899 904 var callbacks = this.get_callbacks_for_msg(parent_id);
900 905 if (callbacks && callbacks.iopub && callbacks.iopub.status) {
901 906 try {
902 907 callbacks.iopub.status(msg);
903 908 } catch (e) {
904 909 console.log("Exception in status msg handler", e, e.stack);
905 910 }
906 911 }
907 912
908 913 if (execution_state === 'busy') {
909 914 this.events.trigger('kernel_busy.Kernel', {kernel: this});
910 915
911 916 } else if (execution_state === 'idle') {
912 917 // signal that iopub callbacks are (probably) done
913 918 // async output may still arrive,
914 919 // but only for the most recent request
915 920 this._finish_iopub(parent_id);
916 921
917 922 // trigger status_idle event
918 923 this.events.trigger('kernel_idle.Kernel', {kernel: this});
919 924
920 925 } else if (execution_state === 'starting') {
921 926 this.events.trigger('kernel_starting.Kernel', {kernel: this});
922 927 var that = this;
923 928 this.kernel_info(function () {
924 929 that.events.trigger('kernel_ready.Kernel', {kernel: that});
925 930 });
926 931
927 932 } else if (execution_state === 'restarting') {
928 933 // autorestarting is distinct from restarting,
929 934 // in that it means the kernel died and the server is restarting it.
930 935 // kernel_restarting sets the notification widget,
931 936 // autorestart shows the more prominent dialog.
932 937 this._autorestart_attempt = this._autorestart_attempt + 1;
933 938 this.events.trigger('kernel_restarting.Kernel', {kernel: this});
934 939 this.events.trigger('kernel_autorestarting.Kernel', {kernel: this, attempt: this._autorestart_attempt});
935 940
936 941 } else if (execution_state === 'dead') {
937 942 this.events.trigger('kernel_dead.Kernel', {kernel: this});
938 943 this._kernel_dead();
939 944 }
940 945 };
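// Usage sketch (assumes `kernel.events` is the notebook's events object):
// the state changes triggered above can be observed from frontend code.
//
//     kernel.events.on('kernel_idle.Kernel', function () {
//         console.log('kernel is idle');
//     });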
941 946
942 947 /**
943 948 * Handle clear_output message
944 949 *
945 950 * @function _handle_clear_output
946 951 */
947 952 Kernel.prototype._handle_clear_output = function (msg) {
948 953 var callbacks = this.get_callbacks_for_msg(msg.parent_header.msg_id);
949 954 if (!callbacks || !callbacks.iopub) {
950 955 return;
951 956 }
952 957 var callback = callbacks.iopub.clear_output;
953 958 if (callback) {
954 959 callback(msg);
955 960 }
956 961 };
957 962
958 963 /**
959 964 * handle an output message (execute_result, display_data, etc.)
960 965 *
961 966 * @function _handle_output_message
962 967 */
963 968 Kernel.prototype._handle_output_message = function (msg) {
964 969 var callbacks = this.get_callbacks_for_msg(msg.parent_header.msg_id);
965 970 if (!callbacks || !callbacks.iopub) {
966 971 return;
967 972 }
968 973 var callback = callbacks.iopub.output;
969 974 if (callback) {
970 975 callback(msg);
971 976 }
972 977 };
973 978
974 979 /**
975 980 * Dispatch IOPub messages to respective handlers. Each message
976 981 * type should have a handler.
977 982 *
978 983 * @function _handle_iopub_message
979 984 */
980 985 Kernel.prototype._handle_iopub_message = function (e) {
981 var msg = $.parseJSON(e.data);
986 serialize.deserialize(e.data, $.proxy(this._finish_iopub_message, this));
987 };
988
982 989
990 Kernel.prototype._finish_iopub_message = function (msg) {
983 991 var handler = this.get_iopub_handler(msg.header.msg_type);
984 992 if (handler !== undefined) {
985 993 handler(msg);
986 994 }
987 995 };
988 996
989 997 /**
990 998 * @function _handle_input_request
991 999 */
992 1000 Kernel.prototype._handle_input_request = function (e) {
993 var request = $.parseJSON(e.data);
1001 serialize.deserialize(e.data, $.proxy(this._finish_input_request, this));
1002 };
1003
1004
1005 Kernel.prototype._finish_input_request = function (request) {
994 1006 var header = request.header;
995 1007 var content = request.content;
996 1008 var metadata = request.metadata;
997 1009 var msg_type = header.msg_type;
998 1010 if (msg_type !== 'input_request') {
999 1011 console.log("Invalid input request!", request);
1000 1012 return;
1001 1013 }
1002 1014 var callbacks = this.get_callbacks_for_msg(request.parent_header.msg_id);
1003 1015 if (callbacks) {
1004 1016 if (callbacks.input) {
1005 1017 callbacks.input(request);
1006 1018 }
1007 1019 }
1008 1020 };
1009 1021
1010 1022 // Backwards compatibility.
1011 1023 IPython.Kernel = Kernel;
1012 1024
1013 1025 return {'Kernel': Kernel};
1014 1026 });
@@ -1,323 +1,324
1 1 {% extends "page.html" %}
2 2
3 3 {% block stylesheet %}
4 4
5 5 {% if mathjax_url %}
6 6 <script type="text/javascript" src="{{mathjax_url}}?config=TeX-AMS_HTML-full&delayStartupUntil=configured" charset="utf-8"></script>
7 7 {% endif %}
8 8 <script type="text/javascript">
9 9 // MathJax disabled, set as null to distinguish from *missing* MathJax,
10 10 // where it will be undefined, and should prompt a dialog later.
11 11 window.mathjax_url = "{{mathjax_url}}";
12 12 </script>
13 13
14 14 <link rel="stylesheet" href="{{ static_url("components/bootstrap-tour/build/css/bootstrap-tour.min.css") }}" type="text/css" />
15 15 <link rel="stylesheet" href="{{ static_url("components/codemirror/lib/codemirror.css") }}">
16 16
17 17 {{super()}}
18 18
19 19 <link rel="stylesheet" href="{{ static_url("notebook/css/override.css") }}" type="text/css" />
20 20
21 21 {% endblock %}
22 22
23 23 {% block params %}
24 24
25 25 data-project="{{project}}"
26 26 data-base-url="{{base_url}}"
27 27 data-ws-url="{{ws_url}}"
28 28 data-notebook-name="{{notebook_name}}"
29 29 data-notebook-path="{{notebook_path}}"
30 30 class="notebook_app"
31 31
32 32 {% endblock %}
33 33
34 34
35 35 {% block header %}
36 36
37 37
38 38 <span id="save_widget" class="nav pull-left">
39 39 <span id="notebook_name"></span>
40 40 <span id="checkpoint_status"></span>
41 41 <span id="autosave_status"></span>
42 42 </span>
43 43
44 44 <span id="kernel_selector_widget" class="pull-right dropdown">
45 45 <button class="dropdown-toggle" data-toggle="dropdown" type='button' id="current_kernel_spec">
46 46 <span class='kernel_name'>Python</span>
47 47 <span class="caret"></span>
48 48 </button>
49 49 <ul id="kernel_selector" class="dropdown-menu">
50 50 </ul>
51 51 </span>
52 52
53 53 {% endblock %}
54 54
55 55
56 56 {% block site %}
57 57
58 58 <div id="menubar-container" class="container">
59 59 <div id="menubar">
60 60 <div id="menus" class="navbar navbar-default" role="navigation">
61 61 <div class="container-fluid">
62 62 <button type="button" class="btn btn-default navbar-toggle" data-toggle="collapse" data-target=".navbar-collapse">
63 63 <i class="fa fa-bars"></i>
64 64 <span class="navbar-text">Menu</span>
65 65 </button>
66 66 <ul class="nav navbar-nav navbar-right">
67 67 <li id="kernel_indicator">
68 68 <i id="kernel_indicator_icon"></i>
69 69 </li>
70 70 <li id="modal_indicator">
71 71 <i id="modal_indicator_icon"></i>
72 72 </li>
73 73 <li id="notification_area"></li>
74 74 </ul>
75 75 <div class="navbar-collapse collapse">
76 76 <ul class="nav navbar-nav">
77 77 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">File</a>
78 78 <ul id="file_menu" class="dropdown-menu">
79 79 <li id="new_notebook"
80 80 title="Make a new notebook (Opens a new window)">
81 81 <a href="#">New</a></li>
82 82 <li id="open_notebook"
83 83 title="Opens a new window with the Dashboard view">
84 84 <a href="#">Open...</a></li>
85 85 <!-- <hr/> -->
86 86 <li class="divider"></li>
87 87 <li id="copy_notebook"
88 88 title="Open a copy of this notebook's contents and start a new kernel">
89 89 <a href="#">Make a Copy...</a></li>
90 90 <li id="rename_notebook"><a href="#">Rename...</a></li>
91 91 <li id="save_checkpoint"><a href="#">Save and Checkpoint</a></li>
92 92 <!-- <hr/> -->
93 93 <li class="divider"></li>
94 94 <li id="restore_checkpoint" class="dropdown-submenu"><a href="#">Revert to Checkpoint</a>
95 95 <ul class="dropdown-menu">
96 96 <li><a href="#"></a></li>
97 97 <li><a href="#"></a></li>
98 98 <li><a href="#"></a></li>
99 99 <li><a href="#"></a></li>
100 100 <li><a href="#"></a></li>
101 101 </ul>
102 102 </li>
103 103 <li class="divider"></li>
104 104 <li id="print_preview"><a href="#">Print Preview</a></li>
105 105 <li class="dropdown-submenu"><a href="#">Download as</a>
106 106 <ul class="dropdown-menu">
107 107 <li id="download_ipynb"><a href="#">IPython Notebook (.ipynb)</a></li>
108 108 <li id="download_py"><a href="#">Python (.py)</a></li>
109 109 <li id="download_html"><a href="#">HTML (.html)</a></li>
110 110 <li id="download_rst"><a href="#">reST (.rst)</a></li>
111 111 <li id="download_pdf"><a href="#">PDF (.pdf)</a></li>
112 112 </ul>
113 113 </li>
114 114 <li class="divider"></li>
115 115 <li id="trust_notebook"
116 116 title="Trust the output of this notebook">
117 117 <a href="#" >Trust Notebook</a></li>
118 118 <li class="divider"></li>
119 119 <li id="kill_and_exit"
120 120 title="Shutdown this notebook's kernel, and close this window">
121 121 <a href="#" >Close and halt</a></li>
122 122 </ul>
123 123 </li>
124 124 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">Edit</a>
125 125 <ul id="edit_menu" class="dropdown-menu">
126 126 <li id="cut_cell"><a href="#">Cut Cell</a></li>
127 127 <li id="copy_cell"><a href="#">Copy Cell</a></li>
128 128 <li id="paste_cell_above" class="disabled"><a href="#">Paste Cell Above</a></li>
129 129 <li id="paste_cell_below" class="disabled"><a href="#">Paste Cell Below</a></li>
130 130 <li id="paste_cell_replace" class="disabled"><a href="#">Paste Cell &amp; Replace</a></li>
131 131 <li id="delete_cell"><a href="#">Delete Cell</a></li>
132 132 <li id="undelete_cell" class="disabled"><a href="#">Undo Delete Cell</a></li>
133 133 <li class="divider"></li>
134 134 <li id="split_cell"><a href="#">Split Cell</a></li>
135 135 <li id="merge_cell_above"><a href="#">Merge Cell Above</a></li>
136 136 <li id="merge_cell_below"><a href="#">Merge Cell Below</a></li>
137 137 <li class="divider"></li>
138 138 <li id="move_cell_up"><a href="#">Move Cell Up</a></li>
139 139 <li id="move_cell_down"><a href="#">Move Cell Down</a></li>
140 140 <li class="divider"></li>
141 141 <li id="edit_nb_metadata"><a href="#">Edit Notebook Metadata</a></li>
142 142 </ul>
143 143 </li>
144 144 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">View</a>
145 145 <ul id="view_menu" class="dropdown-menu">
146 146 <li id="toggle_header"
147 147 title="Show/Hide the IPython Notebook logo and notebook title (above menu bar)">
148 148 <a href="#">Toggle Header</a></li>
149 149 <li id="toggle_toolbar"
150 150 title="Show/Hide the action icons (below menu bar)">
151 151 <a href="#">Toggle Toolbar</a></li>
152 152 </ul>
153 153 </li>
154 154 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">Insert</a>
155 155 <ul id="insert_menu" class="dropdown-menu">
156 156 <li id="insert_cell_above"
157 157 title="Insert an empty Code cell above the currently active cell">
158 158 <a href="#">Insert Cell Above</a></li>
159 159 <li id="insert_cell_below"
160 160 title="Insert an empty Code cell below the currently active cell">
161 161 <a href="#">Insert Cell Below</a></li>
162 162 </ul>
163 163 </li>
164 164 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">Cell</a>
165 165 <ul id="cell_menu" class="dropdown-menu">
166 166 <li id="run_cell" title="Run this cell, and move cursor to the next one">
167 167 <a href="#">Run</a></li>
168 168 <li id="run_cell_select_below" title="Run this cell, select below">
169 169 <a href="#">Run and Select Below</a></li>
170 170 <li id="run_cell_insert_below" title="Run this cell, insert below">
171 171 <a href="#">Run and Insert Below</a></li>
172 172 <li id="run_all_cells" title="Run all cells in the notebook">
173 173 <a href="#">Run All</a></li>
174 174 <li id="run_all_cells_above" title="Run all cells above (but not including) this cell">
175 175 <a href="#">Run All Above</a></li>
176 176 <li id="run_all_cells_below" title="Run this cell and all cells below it">
177 177 <a href="#">Run All Below</a></li>
178 178 <li class="divider"></li>
179 179 <li id="change_cell_type" class="dropdown-submenu"
180 180 title="All cells in the notebook have a cell type. By default, new cells are created as 'Code' cells">
181 181 <a href="#">Cell Type</a>
182 182 <ul class="dropdown-menu">
183 183 <li id="to_code"
184 184 title="Contents will be sent to the kernel for execution, and output will display in the footer of cell">
185 185 <a href="#">Code</a></li>
186 186 <li id="to_markdown"
187 187 title="Contents will be rendered as HTML and serve as explanatory text">
188 188 <a href="#">Markdown</a></li>
189 189 <li id="to_raw"
190 190 title="Contents will pass through nbconvert unmodified">
191 191 <a href="#">Raw NBConvert</a></li>
192 192 <li id="to_heading1"><a href="#">Heading 1</a></li>
193 193 <li id="to_heading2"><a href="#">Heading 2</a></li>
194 194 <li id="to_heading3"><a href="#">Heading 3</a></li>
195 195 <li id="to_heading4"><a href="#">Heading 4</a></li>
196 196 <li id="to_heading5"><a href="#">Heading 5</a></li>
197 197 <li id="to_heading6"><a href="#">Heading 6</a></li>
198 198 </ul>
199 199 </li>
200 200 <li class="divider"></li>
201 201 <li id="current_outputs" class="dropdown-submenu"><a href="#">Current Output</a>
202 202 <ul class="dropdown-menu">
203 203 <li id="toggle_current_output"
204 204 title="Hide/Show the output of the current cell">
205 205 <a href="#">Toggle</a>
206 206 </li>
207 207 <li id="toggle_current_output_scroll"
208 208 title="Scroll the output of the current cell">
209 209 <a href="#">Toggle Scrolling</a>
210 210 </li>
211 211 <li id="clear_current_output"
212 212 title="Clear the output of the current cell">
213 213 <a href="#">Clear</a>
214 214 </li>
215 215 </ul>
216 216 </li>
217 217 <li id="all_outputs" class="dropdown-submenu"><a href="#">All Output</a>
218 218 <ul class="dropdown-menu">
219 219 <li id="toggle_all_output"
220 220 title="Hide/Show the output of all cells">
221 221 <a href="#">Toggle</a>
222 222 </li>
223 223 <li id="toggle_all_output_scroll"
224 224 title="Scroll the output of all cells">
225 225 <a href="#">Toggle Scrolling</a>
226 226 </li>
227 227 <li id="clear_all_output"
228 228 title="Clear the output of all cells">
229 229 <a href="#">Clear</a>
230 230 </li>
231 231 </ul>
232 232 </li>
233 233 </ul>
234 234 </li>
235 235 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">Kernel</a>
236 236 <ul id="kernel_menu" class="dropdown-menu">
237 237 <li id="int_kernel"
238 238 title="Send KeyboardInterrupt (CTRL-C) to the Kernel">
239 239 <a href="#">Interrupt</a></li>
240 240 <li id="restart_kernel"
241 241 title="Restart the Kernel">
242 242 <a href="#">Restart</a></li>
243 243 <li class="divider"></li>
244 244 <li id="menu-change-kernel" class="dropdown-submenu">
245 245 <a href="#">Change kernel</a>
246 246 <ul class="dropdown-menu" id="menu-change-kernel-submenu"></ul>
247 247 </li>
248 248 </ul>
249 249 </li>
250 250 <li class="dropdown"><a href="#" class="dropdown-toggle" data-toggle="dropdown">Help</a>
251 251 <ul id="help_menu" class="dropdown-menu">
252 252 <li id="notebook_tour" title="A quick tour of the notebook user interface"><a href="#">User Interface Tour</a></li>
253 253 <li id="keyboard_shortcuts" title="Opens a tooltip with all keyboard shortcuts"><a href="#">Keyboard Shortcuts</a></li>
254 254 <li class="divider"></li>
255 255 {% set
256 256 sections = (
257 257 (
258 258 ("http://ipython.org/documentation.html","IPython Help",True),
259 259 ("http://nbviewer.ipython.org/github/ipython/ipython/tree/2.x/examples/Index.ipynb", "Notebook Help", True),
260 260 ),(
261 261 ("http://docs.python.org","Python",True),
262 262 ("http://help.github.com/articles/github-flavored-markdown","Markdown",True),
263 263 ("http://docs.scipy.org/doc/numpy/reference/","NumPy",True),
264 264 ("http://docs.scipy.org/doc/scipy/reference/","SciPy",True),
265 265 ("http://matplotlib.org/contents.html","Matplotlib",True),
266 266 ("http://docs.sympy.org/latest/index.html","SymPy",True),
267 267 ("http://pandas.pydata.org/pandas-docs/stable/","pandas", True)
268 268 )
269 269 )
270 270 %}
271 271
272 272 {% for helplinks in sections %}
273 273 {% for link in helplinks %}
274 274 <li><a href="{{link[0]}}" {{'target="_blank" title="Opens in a new window"' if link[2]}}>
275 275 {{'<i class="fa fa-external-link menu-icon pull-right"></i>' if link[2]}}
276 276 {{link[1]}}
277 277 </a></li>
278 278 {% endfor %}
279 279 {% if not loop.last %}
280 280 <li class="divider"></li>
281 281 {% endif %}
282 282 {% endfor %}
283 283 </li>
284 284 </ul>
285 285 </li>
286 286 </ul>
287 287 </div>
288 288 </div>
289 289 </div>
290 290 </div>
291 291 <div id="maintoolbar" class="navbar">
292 292 <div class="toolbar-inner navbar-inner navbar-nobg">
293 293 <div id="maintoolbar-container" class="container"></div>
294 294 </div>
295 295 </div>
296 296 </div>
297 297
298 298 <div id="ipython-main-app">
299 299
300 300 <div id="notebook_panel">
301 301 <div id="notebook"></div>
302 302 <div id="pager_splitter"></div>
303 303 <div id="pager">
304 304 <div id='pager_button_area'>
305 305 </div>
306 306 <div id="pager-container" class="container"></div>
307 307 </div>
308 308 </div>
309 309
310 310 </div>
311 311 <div id='tooltip' class='ipython_tooltip' style='display:none'></div>
312 312
313 313
314 314 {% endblock %}
315 315
316 316
317 317 {% block script %}
318 318 {{super()}}
319 319
320 <script src="{{ static_url("components/text-encoding/lib/encoding.js") }}" charset="utf-8"></script>
320 321
321 322 <script src="{{ static_url("notebook/js/main.js") }}" charset="utf-8"></script>
322 323
323 324 {% endblock %}
@@ -1,644 +1,644
1 1 """Base classes to manage a Client's interaction with a running kernel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import absolute_import
7 7
8 8 import atexit
9 9 import errno
10 10 from threading import Thread
11 11 import time
12 12
13 13 import zmq
14 14 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
15 15 # during garbage collection of threads at exit:
16 16 from zmq import ZMQError
17 17 from zmq.eventloop import ioloop, zmqstream
18 18
19 19 from IPython.core.release import kernel_protocol_version_info
20 20
21 21 from .channelsabc import (
22 22 ShellChannelABC, IOPubChannelABC,
23 23 HBChannelABC, StdInChannelABC,
24 24 )
25 25 from IPython.utils.py3compat import string_types, iteritems
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Constants and exceptions
29 29 #-----------------------------------------------------------------------------
30 30
31 31 major_protocol_version = kernel_protocol_version_info[0]
32 32
33 33 class InvalidPortNumber(Exception):
34 34 pass
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Utility functions
38 38 #-----------------------------------------------------------------------------
39 39
40 40 # some utilities to validate message structure; these might get moved elsewhere
41 41 # if they prove to have more generic utility
42 42
43 43 def validate_string_list(lst):
44 44 """Validate that the input is a list of strings.
45 45
46 46 Raises ValueError if not."""
47 47 if not isinstance(lst, list):
48 48 raise ValueError('input %r must be a list' % lst)
49 49 for x in lst:
50 50 if not isinstance(x, string_types):
51 51 raise ValueError('element %r in list must be a string' % x)
52 52
53 53
54 54 def validate_string_dict(dct):
55 55 """Validate that the input is a dict with string keys and values.
56 56
57 57 Raises ValueError if not."""
58 58 for k,v in iteritems(dct):
59 59 if not isinstance(k, string_types):
60 60 raise ValueError('key %r in dict must be a string' % k)
61 61 if not isinstance(v, string_types):
62 62 raise ValueError('value %r in dict must be a string' % v)
63 63
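# Usage sketch: the validators above guard message content before it goes
# over the wire.
#
#     validate_string_dict({'x': 'repr(x)'})   # passes silently
#     validate_string_dict({'x': 1})           # raises ValueError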
64 64
65 65 #-----------------------------------------------------------------------------
66 66 # ZMQ Socket Channel classes
67 67 #-----------------------------------------------------------------------------
68 68
69 69 class ZMQSocketChannel(Thread):
70 70 """The base class for the channels that use ZMQ sockets."""
71 71 context = None
72 72 session = None
73 73 socket = None
74 74 ioloop = None
75 75 stream = None
76 76 _address = None
77 77 _exiting = False
78 78 proxy_methods = []
79 79
80 80 def __init__(self, context, session, address):
81 81 """Create a channel.
82 82
83 83 Parameters
84 84 ----------
85 85 context : :class:`zmq.Context`
86 86 The ZMQ context to use.
87 87 session : :class:`session.Session`
88 88 The session to use.
89 89 address : zmq url
90 90 Standard (ip, port) tuple that the kernel is listening on.
91 91 """
92 92 super(ZMQSocketChannel, self).__init__()
93 93 self.daemon = True
94 94
95 95 self.context = context
96 96 self.session = session
97 97 if isinstance(address, tuple):
98 98 if address[1] == 0:
99 99 message = 'The port number for a channel cannot be 0.'
100 100 raise InvalidPortNumber(message)
101 101 address = "tcp://%s:%i" % address
102 102 self._address = address
103 103 atexit.register(self._notice_exit)
104 104
105 105 def _notice_exit(self):
106 106 self._exiting = True
107 107
108 108 def _run_loop(self):
109 109 """Run my loop, ignoring EINTR events in the poller"""
110 110 while True:
111 111 try:
112 112 self.ioloop.start()
113 113 except ZMQError as e:
114 114 if e.errno == errno.EINTR:
115 115 continue
116 116 else:
117 117 raise
118 118 except Exception:
119 119 if self._exiting:
120 120 break
121 121 else:
122 122 raise
123 123 else:
124 124 break
125 125
126 126 def stop(self):
127 127 """Stop the channel's event loop and join its thread.
128 128
129 129 This calls :meth:`~threading.Thread.join` and returns when the thread
130 130 terminates. :class:`RuntimeError` will be raised if
131 131 :meth:`~threading.Thread.start` is called again.
132 132 """
133 133 if self.ioloop is not None:
134 134 self.ioloop.stop()
135 135 self.join()
136 136 self.close()
137 137
138 138 def close(self):
139 139 if self.ioloop is not None:
140 140 try:
141 141 self.ioloop.close(all_fds=True)
142 142 except Exception:
143 143 pass
144 144 if self.socket is not None:
145 145 try:
146 146 self.socket.close(linger=0)
147 147 except Exception:
148 148 pass
149 149 self.socket = None
150 150
151 151 @property
152 152 def address(self):
153 153 """Get the channel's address as a zmq url string.
154 154
155 155 These URLS have the form: 'tcp://127.0.0.1:5555'.
156 156 """
157 157 return self._address
158 158
159 159 def _queue_send(self, msg):
160 160 """Queue a message to be sent from the IOLoop's thread.
161 161
162 162 Parameters
163 163 ----------
164 164 msg : message to send
165 165
166 166 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
167 167 thread control of the action.
168 168 """
169 169 def thread_send():
170 170 self.session.send(self.stream, msg)
171 171 self.ioloop.add_callback(thread_send)
172 172
173 173 def _handle_recv(self, msg):
174 174 """Callback for stream.on_recv.
175 175
176 176 Unpacks message, and calls handlers with it.
177 177 """
178 178 ident,smsg = self.session.feed_identities(msg)
179 msg = self.session.unserialize(smsg)
179 msg = self.session.deserialize(smsg)
180 180 self.call_handlers(msg)
181 181
182 182
183 183
184 184 class ShellChannel(ZMQSocketChannel):
185 185 """The shell channel for issuing request/replies to the kernel."""
186 186
187 187 command_queue = None
188 188 # flag for whether execute requests should be allowed to call raw_input:
189 189 allow_stdin = True
190 190 proxy_methods = [
191 191 'execute',
192 192 'complete',
193 193 'inspect',
194 194 'history',
195 195 'kernel_info',
196 196 'shutdown',
197 197 'is_complete',
198 198 ]
199 199
200 200 def __init__(self, context, session, address):
201 201 super(ShellChannel, self).__init__(context, session, address)
202 202 self.ioloop = ioloop.IOLoop()
203 203
204 204 def run(self):
205 205 """The thread's main activity. Call start() instead."""
206 206 self.socket = self.context.socket(zmq.DEALER)
207 207 self.socket.linger = 1000
208 208 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
209 209 self.socket.connect(self.address)
210 210 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
211 211 self.stream.on_recv(self._handle_recv)
212 212 self._run_loop()
213 213
214 214 def call_handlers(self, msg):
215 215 """This method is called in the ioloop thread when a message arrives.
216 216
217 217 Subclasses should override this method to handle incoming messages.
218 218 It is important to remember that this method is called in the thread
219 219 so that some logic must be done to ensure that the application level
220 220 handlers are called in the application thread.
221 221 """
222 222 raise NotImplementedError('call_handlers must be defined in a subclass.')
223 223
224 224 def execute(self, code, silent=False, store_history=True,
225 225 user_expressions=None, allow_stdin=None):
226 226 """Execute code in the kernel.
227 227
228 228 Parameters
229 229 ----------
230 230 code : str
231 231 A string of Python code.
232 232
233 233 silent : bool, optional (default False)
234 234 If set, the kernel will execute the code as quietly as possible, and
235 235 will force store_history to be False.
236 236
237 237 store_history : bool, optional (default True)
238 238 If set, the kernel will store command history. This is forced
239 239 to be False if silent is True.
240 240
241 241 user_expressions : dict, optional
242 242 A dict mapping names to expressions to be evaluated in the user's
243 243 dict. The expression values are returned as strings formatted using
244 244 :func:`repr`.
245 245
246 246 allow_stdin : bool, optional (default self.allow_stdin)
247 247 Flag for whether the kernel can send stdin requests to frontends.
248 248
249 249 Some frontends (e.g. the Notebook) do not support stdin requests.
250 250 If raw_input is called from code executed from such a frontend, a
251 251 StdinNotImplementedError will be raised.
252 252
253 253 Returns
254 254 -------
255 255 The msg_id of the message sent.
256 256 """
257 257 if user_expressions is None:
258 258 user_expressions = {}
259 259 if allow_stdin is None:
260 260 allow_stdin = self.allow_stdin
261 261
262 262
263 263 # Don't waste network traffic if inputs are invalid
264 264 if not isinstance(code, string_types):
265 265 raise ValueError('code %r must be a string' % code)
266 266 validate_string_dict(user_expressions)
267 267
268 268 # Create class for content/msg creation. Related to, but possibly
269 269 # not in Session.
270 270 content = dict(code=code, silent=silent, store_history=store_history,
271 271 user_expressions=user_expressions,
272 272 allow_stdin=allow_stdin,
273 273 )
274 274 msg = self.session.msg('execute_request', content)
275 275 self._queue_send(msg)
276 276 return msg['header']['msg_id']
277 277
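# Usage sketch (assumes `shell` is a started ShellChannel with a live
# kernel connection); replies arrive asynchronously via call_handlers()
# in the channel's ioloop thread.
#
#     msg_id = shell.execute("a = 10", silent=False, store_history=True)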
278 278 def complete(self, code, cursor_pos=None):
279 279 """Tab complete text in the kernel's namespace.
280 280
281 281 Parameters
282 282 ----------
283 283 code : str
284 284 The context in which completion is requested.
285 285 Can be anything between a variable name and an entire cell.
286 286 cursor_pos : int, optional
287 287 The position of the cursor in the block of code where the completion was requested.
288 288 Default: ``len(code)``
289 289
290 290 Returns
291 291 -------
292 292 The msg_id of the message sent.
293 293 """
294 294 if cursor_pos is None:
295 295 cursor_pos = len(code)
296 296 content = dict(code=code, cursor_pos=cursor_pos)
297 297 msg = self.session.msg('complete_request', content)
298 298 self._queue_send(msg)
299 299 return msg['header']['msg_id']
300 300
301 301 def inspect(self, code, cursor_pos=None, detail_level=0):
302 302 """Get metadata information about an object in the kernel's namespace.
303 303
304 304 It is up to the kernel to determine the appropriate object to inspect.
305 305
306 306 Parameters
307 307 ----------
308 308 code : str
309 309 The context in which info is requested.
310 310 Can be anything between a variable name and an entire cell.
311 311 cursor_pos : int, optional
312 312 The position of the cursor in the block of code where the info was requested.
313 313 Default: ``len(code)``
314 314 detail_level : int, optional
315 315 The level of detail for the introspection (0-2)
316 316
317 317 Returns
318 318 -------
319 319 The msg_id of the message sent.
320 320 """
321 321 if cursor_pos is None:
322 322 cursor_pos = len(code)
323 323 content = dict(code=code, cursor_pos=cursor_pos,
324 324 detail_level=detail_level,
325 325 )
326 326 msg = self.session.msg('inspect_request', content)
327 327 self._queue_send(msg)
328 328 return msg['header']['msg_id']
329 329
330 330 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
331 331 """Get entries from the kernel's history list.
332 332
333 333 Parameters
334 334 ----------
335 335 raw : bool
336 336 If True, return the raw input.
337 337 output : bool
338 338 If True, then return the output as well.
339 339 hist_access_type : str
340 340 'range' (fill in session, start and stop params), 'tail' (fill in n)
341 341 or 'search' (fill in pattern param).
342 342
343 343 session : int
344 344 For a range request, the session from which to get lines. Session
345 345 numbers are positive integers; negative ones count back from the
346 346 current session.
347 347 start : int
348 348 The first line number of a history range.
349 349 stop : int
350 350 The final (excluded) line number of a history range.
351 351
352 352 n : int
353 353 The number of lines of history to get for a tail request.
354 354
355 355 pattern : str
356 356 The glob-syntax pattern for a search request.
357 357
358 358 Returns
359 359 -------
360 360 The msg_id of the message sent.
361 361 """
362 362 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
363 363 **kwargs)
364 364 msg = self.session.msg('history_request', content)
365 365 self._queue_send(msg)
366 366 return msg['header']['msg_id']
367 367
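# Usage sketch (assumes `shell` is a started ShellChannel): the extra
# kwargs map directly onto the history_request content described above.
#
#     msg_id = shell.history(hist_access_type='tail', n=10)
#     msg_id = shell.history(hist_access_type='search', pattern='import *')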
368 368 def kernel_info(self):
369 369 """Request kernel info."""
370 370 msg = self.session.msg('kernel_info_request')
371 371 self._queue_send(msg)
372 372 return msg['header']['msg_id']
373 373
374 374 def _handle_kernel_info_reply(self, msg):
375 375 """handle kernel info reply
376 376
377 377 sets protocol adaptation version
378 378 """
379 379 adapt_version = int(msg['content']['protocol_version'].split('.')[0])
380 380 if adapt_version != major_protocol_version:
381 381 self.session.adapt_version = adapt_version
382 382
383 383 def shutdown(self, restart=False):
384 384 """Request an immediate kernel shutdown.
385 385
386 386 Upon receipt of the (empty) reply, client code can safely assume that
387 387 the kernel has shut down and it's safe to forcefully terminate it if
388 388 it's still alive.
389 389
390 390 The kernel will send the reply via a function registered with Python's
391 391 atexit module, ensuring it's truly done as the kernel is done with all
392 392 normal operation.
393 393 """
394 394 # Send quit message to kernel. Once we implement kernel-side setattr,
395 395 # this should probably be done that way, but for now this will do.
396 396 msg = self.session.msg('shutdown_request', {'restart':restart})
397 397 self._queue_send(msg)
398 398 return msg['header']['msg_id']
399 399
400 400 def is_complete(self, code):
401 401 msg = self.session.msg('is_complete_request', {'code': code})
402 402 self._queue_send(msg)
403 403 return msg['header']['msg_id']
404 404
405 405
406 406 class IOPubChannel(ZMQSocketChannel):
407 407 """The iopub channel which listens for messages that the kernel publishes.
408 408
409 409 This channel is where all output is published to frontends.
410 410 """
411 411
412 412 def __init__(self, context, session, address):
413 413 super(IOPubChannel, self).__init__(context, session, address)
414 414 self.ioloop = ioloop.IOLoop()
415 415
416 416 def run(self):
417 417 """The thread's main activity. Call start() instead."""
418 418 self.socket = self.context.socket(zmq.SUB)
419 419 self.socket.linger = 1000
420 420 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
421 421 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
422 422 self.socket.connect(self.address)
423 423 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
424 424 self.stream.on_recv(self._handle_recv)
425 425 self._run_loop()
426 426
427 427 def call_handlers(self, msg):
428 428 """This method is called in the ioloop thread when a message arrives.
429 429
430 430 Subclasses should override this method to handle incoming messages.
431 431 It is important to remember that this method is called in the thread
432 432 so that some logic must be done to ensure that the application level
433 433 handlers are called in the application thread.
434 434 """
435 435 raise NotImplementedError('call_handlers must be defined in a subclass.')
436 436
437 437 def flush(self, timeout=1.0):
438 438 """Immediately processes all pending messages on the iopub channel.
439 439
440 440 Callers should use this method to ensure that :meth:`call_handlers`
441 441 has been called for all messages that have been received on the
442 442 0MQ SUB socket of this channel.
443 443
444 444 This method is thread safe.
445 445
446 446 Parameters
447 447 ----------
448 448 timeout : float, optional
449 449 The maximum amount of time to spend flushing, in seconds. The
450 450 default is one second.
451 451 """
452 452 # We do the IOLoop callback process twice to ensure that the IOLoop
453 453 # gets to perform at least one full poll.
454 454 stop_time = time.time() + timeout
455 455 for i in range(2):
456 456 self._flushed = False
457 457 self.ioloop.add_callback(self._flush)
458 458 while not self._flushed and time.time() < stop_time:
459 459 time.sleep(0.01)
460 460
461 461 def _flush(self):
462 462 """Callback for :method:`self.flush`."""
463 463 self.stream.flush()
464 464 self._flushed = True
465 465
466 466
467 467 class StdInChannel(ZMQSocketChannel):
468 468 """The stdin channel to handle raw_input requests that the kernel makes."""
469 469
470 470 msg_queue = None
471 471 proxy_methods = ['input']
472 472
473 473 def __init__(self, context, session, address):
474 474 super(StdInChannel, self).__init__(context, session, address)
475 475 self.ioloop = ioloop.IOLoop()
476 476
477 477 def run(self):
478 478 """The thread's main activity. Call start() instead."""
479 479 self.socket = self.context.socket(zmq.DEALER)
480 480 self.socket.linger = 1000
481 481 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
482 482 self.socket.connect(self.address)
483 483 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
484 484 self.stream.on_recv(self._handle_recv)
485 485 self._run_loop()
486 486
487 487 def call_handlers(self, msg):
488 488 """This method is called in the ioloop thread when a message arrives.
489 489
490 490 Subclasses should override this method to handle incoming messages.
491 491 It is important to remember that this method is called in the thread
492 492 so that some logic must be done to ensure that the application level
493 493 handlers are called in the application thread.
494 494 """
495 495 raise NotImplementedError('call_handlers must be defined in a subclass.')
496 496
497 497 def input(self, string):
498 498 """Send a string of raw input to the kernel."""
499 499 content = dict(value=string)
500 500 msg = self.session.msg('input_reply', content)
501 501 self._queue_send(msg)
502 502
503 503
504 504 class HBChannel(ZMQSocketChannel):
505 505 """The heartbeat channel which monitors the kernel heartbeat.
506 506
507 507 Note that the heartbeat channel is paused by default. As long as you start
508 508 this channel, the kernel manager will ensure that it is paused and un-paused
509 509 as appropriate.
510 510 """
511 511
512 512 time_to_dead = 3.0
513 513 socket = None
514 514 poller = None
515 515 _running = None
516 516 _pause = None
517 517 _beating = None
518 518
519 519 def __init__(self, context, session, address):
520 520 super(HBChannel, self).__init__(context, session, address)
521 521 self._running = False
522 522 self._pause = True
523 523 self.poller = zmq.Poller()
524 524
525 525 def _create_socket(self):
526 526 if self.socket is not None:
527 527 # close previous socket, before opening a new one
528 528 self.poller.unregister(self.socket)
529 529 self.socket.close()
530 530 self.socket = self.context.socket(zmq.REQ)
531 531 self.socket.linger = 1000
532 532 self.socket.connect(self.address)
533 533
534 534 self.poller.register(self.socket, zmq.POLLIN)
535 535
536 536 def _poll(self, start_time):
537 537 """poll for heartbeat replies until we reach self.time_to_dead.
538 538
539 539 Ignores interrupts, and returns the result of poll(), which
540 540 will be an empty list if no messages arrived before the timeout,
541 541 or the event tuple if there is a message to receive.
542 542 """
543 543
544 544 until_dead = self.time_to_dead - (time.time() - start_time)
545 545 # ensure poll at least once
546 546 until_dead = max(until_dead, 1e-3)
547 547 events = []
548 548 while True:
549 549 try:
550 550 events = self.poller.poll(1000 * until_dead)
551 551 except ZMQError as e:
552 552 if e.errno == errno.EINTR:
553 553 # ignore interrupts during heartbeat
554 554 # this may never actually happen
555 555 until_dead = self.time_to_dead - (time.time() - start_time)
556 556 until_dead = max(until_dead, 1e-3)
557 557 pass
558 558 else:
559 559 raise
560 560 except Exception:
561 561 if self._exiting:
562 562 break
563 563 else:
564 564 raise
565 565 else:
566 566 break
567 567 return events
568 568
569 569 def run(self):
570 570 """The thread's main activity. Call start() instead."""
571 571 self._create_socket()
572 572 self._running = True
573 573 self._beating = True
574 574
575 575 while self._running:
576 576 if self._pause:
577 577 # just sleep, and skip the rest of the loop
578 578 time.sleep(self.time_to_dead)
579 579 continue
580 580
581 581 since_last_heartbeat = 0.0
582 582 # io.rprint('Ping from HB channel') # dbg
583 583 # no need to catch EFSM here, because the previous event was
584 584 # either a recv or connect, which cannot be followed by EFSM
585 585 self.socket.send(b'ping')
586 586 request_time = time.time()
587 587 ready = self._poll(request_time)
588 588 if ready:
589 589 self._beating = True
590 590 # the poll above guarantees we have something to recv
591 591 self.socket.recv()
592 592 # sleep the remainder of the cycle
593 593 remainder = self.time_to_dead - (time.time() - request_time)
594 594 if remainder > 0:
595 595 time.sleep(remainder)
596 596 continue
597 597 else:
598 598 # nothing was received within the time limit, signal heart failure
599 599 self._beating = False
600 600 since_last_heartbeat = time.time() - request_time
601 601 self.call_handlers(since_last_heartbeat)
602 602 # and close/reopen the socket, because the REQ/REP cycle has been broken
603 603 self._create_socket()
604 604 continue
605 605
606 606 def pause(self):
607 607 """Pause the heartbeat."""
608 608 self._pause = True
609 609
610 610 def unpause(self):
611 611 """Unpause the heartbeat."""
612 612 self._pause = False
613 613
614 614 def is_beating(self):
615 615 """Is the heartbeat running and responsive (and not paused)."""
616 616 if self.is_alive() and not self._pause and self._beating:
617 617 return True
618 618 else:
619 619 return False
620 620
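# Usage sketch (assumes `hb` is an HBChannel instance): the heartbeat
# starts paused, per the class docstring, so unpause before polling.
#
#     hb.start()
#     hb.unpause()
#     alive = hb.is_beating()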
621 621 def stop(self):
622 622 """Stop the channel's event loop and join its thread."""
623 623 self._running = False
624 624 super(HBChannel, self).stop()
625 625
626 626 def call_handlers(self, since_last_heartbeat):
627 627 """This method is called in the ioloop thread when a message arrives.
628 628
629 629 Subclasses should override this method to handle incoming messages.
630 630 It is important to remember that this method is called in the thread
631 631 so that some logic must be done to ensure that the application level
632 632 handlers are called in the application thread.
633 633 """
634 634 raise NotImplementedError('call_handlers must be defined in a subclass.')
635 635
636 636
637 637 #-----------------------------------------------------------------------------
638 638 # ABC Registration
639 639 #-----------------------------------------------------------------------------
640 640
641 641 ShellChannelABC.register(ShellChannel)
642 642 IOPubChannelABC.register(IOPubChannel)
643 643 HBChannelABC.register(HBChannel)
644 644 StdInChannelABC.register(StdInChannel)
@@ -1,145 +1,153
1 1 """Base class for a Comm"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import uuid
7 7
8 8 from IPython.config import LoggingConfigurable
9 9 from IPython.kernel.zmq.kernelbase import Kernel
10 10
11 11 from IPython.utils.jsonutil import json_clean
12 12 from IPython.utils.traitlets import Instance, Unicode, Bytes, Bool, Dict, Any
13 13
14 14
15 15 class Comm(LoggingConfigurable):
16 16
17 17 # If this is instantiated by a non-IPython kernel, shell will be None
18 18 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
19 19 allow_none=True)
20 20 kernel = Instance('IPython.kernel.zmq.kernelbase.Kernel')
21 21 def _kernel_default(self):
22 22 if Kernel.initialized():
23 23 return Kernel.instance()
24 24
25 25 iopub_socket = Any()
26 26 def _iopub_socket_default(self):
27 27 return self.kernel.iopub_socket
28 28 session = Instance('IPython.kernel.zmq.session.Session')
29 29 def _session_default(self):
30 30 if self.kernel is not None:
31 31 return self.kernel.session
32 32
33 33 target_name = Unicode('comm')
34 34
35 35 topic = Bytes()
36 36 def _topic_default(self):
37 37 return ('comm-%s' % self.comm_id).encode('ascii')
38 38
39 39 _open_data = Dict(help="data dict, if any, to be included in comm_open")
40 40 _close_data = Dict(help="data dict, if any, to be included in comm_close")
41 41
42 42 _msg_callback = Any()
43 43 _close_callback = Any()
44 44
45 45 _closed = Bool(False)
46 46 comm_id = Unicode()
47 47 def _comm_id_default(self):
48 48 return uuid.uuid4().hex
49 49
50 50 primary = Bool(True, help="Am I the primary or secondary Comm?")
51 51
52 52 def __init__(self, target_name='', data=None, **kwargs):
53 53 if target_name:
54 54 kwargs['target_name'] = target_name
55 55 super(Comm, self).__init__(**kwargs)
56 56 if self.primary:
57 57 # I am primary, open my peer.
58 58 self.open(data)
59 59
60 def _publish_msg(self, msg_type, data=None, metadata=None, **keys):
60 def _publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys):
61 61 """Helper for sending a comm message on IOPub"""
62 62 data = {} if data is None else data
63 63 metadata = {} if metadata is None else metadata
64 64 content = json_clean(dict(data=data, comm_id=self.comm_id, **keys))
65 65 self.session.send(self.iopub_socket, msg_type,
66 66 content,
67 67 metadata=json_clean(metadata),
68 68 parent=self.kernel._parent_header,
69 69 ident=self.topic,
70 buffers=buffers,
70 71 )
71 72
72 73 def __del__(self):
73 74 """trigger close on gc"""
74 75 self.close()
75 76
76 77 # publishing messages
77 78
78 def open(self, data=None, metadata=None):
79 def open(self, data=None, metadata=None, buffers=None):
79 80 """Open the frontend-side version of this comm"""
80 81 if data is None:
81 82 data = self._open_data
82 83 comm_manager = getattr(self.kernel, 'comm_manager', None)
83 84 if comm_manager is None:
84 85 raise RuntimeError("Comms cannot be opened without a kernel "
85 86 "and a comm_manager attached to that kernel.")
86 87
87 88 comm_manager.register_comm(self)
88 89 self._closed = False
89 self._publish_msg('comm_open', data, metadata, target_name=self.target_name)
90 self._publish_msg('comm_open',
91 data=data, metadata=metadata, buffers=buffers,
92 target_name=self.target_name,
93 )
90 94
91 def close(self, data=None, metadata=None):
95 def close(self, data=None, metadata=None, buffers=None):
92 96 """Close the frontend-side version of this comm"""
93 97 if self._closed:
94 98 # only close once
95 99 return
96 100 if data is None:
97 101 data = self._close_data
98 self._publish_msg('comm_close', data, metadata)
102 self._publish_msg('comm_close',
103 data=data, metadata=metadata, buffers=buffers,
104 )
99 105 self.kernel.comm_manager.unregister_comm(self)
100 106 self._closed = True
101 107
102 def send(self, data=None, metadata=None):
108 def send(self, data=None, metadata=None, buffers=None):
103 109 """Send a message to the frontend-side version of this comm"""
104 self._publish_msg('comm_msg', data, metadata)
110 self._publish_msg('comm_msg',
111 data=data, metadata=metadata, buffers=buffers,
112 )
105 113
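# Usage sketch: opening a comm and sending a message that carries binary
# buffers via the new `buffers` argument; the target name is illustrative.
#
#     comm = Comm(target_name='example_target', data={'greeting': 'hello'})
#     comm.send(data={'n': 3}, buffers=[b'\x00' * 16])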
106 114 # registering callbacks
107 115
108 116 def on_close(self, callback):
109 117 """Register a callback for comm_close
110 118
111 119 Will be called with the `data` of the close message.
112 120
113 121 Call `on_close(None)` to disable an existing callback.
114 122 """
115 123 self._close_callback = callback
116 124
117 125 def on_msg(self, callback):
118 126 """Register a callback for comm_msg
119 127
120 128 Will be called with the `data` of any comm_msg messages.
121 129
122 130 Call `on_msg(None)` to disable an existing callback.
123 131 """
124 132 self._msg_callback = callback
125 133
126 134 # handling of incoming messages
127 135
128 136 def handle_close(self, msg):
129 137 """Handle a comm_close message"""
130 138 self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
131 139 if self._close_callback:
132 140 self._close_callback(msg)
133 141
134 142 def handle_msg(self, msg):
135 143 """Handle a comm_msg message"""
136 144 self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
137 145 if self._msg_callback:
138 146 if self.shell:
139 147 self.shell.events.trigger('pre_execute')
140 148 self._msg_callback(msg)
141 149 if self.shell:
142 150 self.shell.events.trigger('post_execute')
143 151
144 152
145 153 __all__ = ['Comm']
@@ -1,692 +1,692
1 1 """Base class for a kernel that talks to frontends over 0MQ."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import sys
9 9 import time
10 10 import logging
11 11 import uuid
12 12
13 13 from datetime import datetime
14 14 from signal import (
15 15 signal, default_int_handler, SIGINT
16 16 )
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop
20 20 from zmq.eventloop.zmqstream import ZMQStream
21 21
22 22 from IPython.config.configurable import SingletonConfigurable
23 23 from IPython.core.error import StdinNotImplementedError
24 24 from IPython.core import release
25 25 from IPython.utils import py3compat
26 26 from IPython.utils.py3compat import unicode_type, string_types
27 27 from IPython.utils.jsonutil import json_clean
28 28 from IPython.utils.traitlets import (
29 29 Any, Instance, Float, Dict, List, Set, Integer, Unicode, Bool,
30 30 )
31 31
32 32 from .session import Session
33 33
34 34
35 35 class Kernel(SingletonConfigurable):
36 36
37 37 #---------------------------------------------------------------------------
38 38 # Kernel interface
39 39 #---------------------------------------------------------------------------
40 40
41 41 # attribute to override with a GUI
42 42 eventloop = Any(None)
43 43 def _eventloop_changed(self, name, old, new):
44 44 """schedule call to eventloop from IOLoop"""
45 45 loop = ioloop.IOLoop.instance()
46 46 loop.add_callback(self.enter_eventloop)
47 47
48 48 session = Instance(Session)
49 49 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
50 50 shell_streams = List()
51 51 control_stream = Instance(ZMQStream)
52 52 iopub_socket = Instance(zmq.Socket)
53 53 stdin_socket = Instance(zmq.Socket)
54 54 log = Instance(logging.Logger)
55 55
56 56 # identities:
57 57 int_id = Integer(-1)
58 58 ident = Unicode()
59 59
60 60 def _ident_default(self):
61 61 return unicode_type(uuid.uuid4())
62 62
63 63 # Private interface
64 64
65 65 _darwin_app_nap = Bool(True, config=True,
66 66 help="""Whether to use appnope for compatibility with OS X App Nap.
67 67
68 68 Only affects OS X >= 10.9.
69 69 """
70 70 )
71 71
72 72 # track associations with current request
73 73 _allow_stdin = Bool(False)
74 74 _parent_header = Dict()
75 75 _parent_ident = Any(b'')
76 76 # Time to sleep after flushing the stdout/err buffers in each execute
77 77 # cycle. While this introduces a hard limit on the minimal latency of the
78 78 # execute cycle, it helps prevent output synchronization problems for
79 79 # clients.
80 80 # Units are in seconds. The minimum zmq latency on local host is probably
81 81 # ~150 microseconds, set this to 500us for now. We may need to increase it
82 82 # a little if it's not enough after more interactive testing.
83 83 _execute_sleep = Float(0.0005, config=True)
84 84
85 85 # Frequency of the kernel's event loop.
86 86 # Units are in seconds, kernel subclasses for GUI toolkits may need to
87 87 # adapt to milliseconds.
88 88 _poll_interval = Float(0.05, config=True)
89 89
90 90 # If the shutdown was requested over the network, we leave here the
91 91 # necessary reply message so it can be sent by our registered atexit
92 92 # handler. This ensures that the reply is only sent to clients truly at
93 93 # the end of our shutdown process (which happens after the underlying
94 94 # IPython shell's own shutdown).
95 95 _shutdown_message = None
96 96
97 97 # This is a dict of port number that the kernel is listening on. It is set
98 98 # by record_ports and used by connect_request.
99 99 _recorded_ports = Dict()
100 100
101 101 # set of aborted msg_ids
102 102 aborted = Set()
103 103
104 104 # Track execution count here. For IPython, we override this to use the
105 105 # execution count we store in the shell.
106 106 execution_count = 0
107 107
108 108
109 109 def __init__(self, **kwargs):
110 110 super(Kernel, self).__init__(**kwargs)
111 111
112 112 # Build dict of handlers for message types
113 113 msg_types = [ 'execute_request', 'complete_request',
114 114 'inspect_request', 'history_request',
115 115 'kernel_info_request',
116 116 'connect_request', 'shutdown_request',
117 117 'apply_request', 'is_complete_request',
118 118 ]
119 119 self.shell_handlers = {}
120 120 for msg_type in msg_types:
121 121 self.shell_handlers[msg_type] = getattr(self, msg_type)
122 122
123 123 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
124 124 self.control_handlers = {}
125 125 for msg_type in control_msg_types:
126 126 self.control_handlers[msg_type] = getattr(self, msg_type)
127 127
128 128
129 129 def dispatch_control(self, msg):
130 130 """dispatch control requests"""
131 131 idents,msg = self.session.feed_identities(msg, copy=False)
132 132 try:
133 msg = self.session.unserialize(msg, content=True, copy=False)
133 msg = self.session.deserialize(msg, content=True, copy=False)
134 134 except:
135 135 self.log.error("Invalid Control Message", exc_info=True)
136 136 return
137 137
138 138 self.log.debug("Control received: %s", msg)
139 139
140 140 # Set the parent message for side effects.
141 141 self.set_parent(idents, msg)
142 142 self._publish_status(u'busy')
143 143
144 144 header = msg['header']
145 145 msg_type = header['msg_type']
146 146
147 147 handler = self.control_handlers.get(msg_type, None)
148 148 if handler is None:
149 149 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
150 150 else:
151 151 try:
152 152 handler(self.control_stream, idents, msg)
153 153 except Exception:
154 154 self.log.error("Exception in control handler:", exc_info=True)
155 155
156 156 sys.stdout.flush()
157 157 sys.stderr.flush()
158 158 self._publish_status(u'idle')
159 159
160 160 def dispatch_shell(self, stream, msg):
161 161 """dispatch shell requests"""
162 162 # flush control requests first
163 163 if self.control_stream:
164 164 self.control_stream.flush()
165 165
166 166 idents,msg = self.session.feed_identities(msg, copy=False)
167 167 try:
168 msg = self.session.unserialize(msg, content=True, copy=False)
168 msg = self.session.deserialize(msg, content=True, copy=False)
169 169 except:
170 170 self.log.error("Invalid Message", exc_info=True)
171 171 return
172 172
173 173 # Set the parent message for side effects.
174 174 self.set_parent(idents, msg)
175 175 self._publish_status(u'busy')
176 176
177 177 header = msg['header']
178 178 msg_id = header['msg_id']
179 179 msg_type = msg['header']['msg_type']
180 180
181 181 # Print some info about this message and leave a '--->' marker, so it's
182 182 easier to visually trace the message chain when debugging. Each
183 183 # handler prints its message at the end.
184 184 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
185 185 self.log.debug(' Content: %s\n --->\n ', msg['content'])
186 186
187 187 if msg_id in self.aborted:
188 188 self.aborted.remove(msg_id)
189 189 # is it safe to assume a msg_id will not be resubmitted?
190 190 reply_type = msg_type.split('_')[0] + '_reply'
191 191 status = {'status' : 'aborted'}
192 192 md = {'engine' : self.ident}
193 193 md.update(status)
194 194 self.session.send(stream, reply_type, metadata=md,
195 195 content=status, parent=msg, ident=idents)
196 196 return
197 197
198 198 handler = self.shell_handlers.get(msg_type, None)
199 199 if handler is None:
200 200 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
201 201 else:
202 202 # ensure default_int_handler during handler call
203 203 sig = signal(SIGINT, default_int_handler)
204 204 self.log.debug("%s: %s", msg_type, msg)
205 205 try:
206 206 handler(stream, idents, msg)
207 207 except Exception:
208 208 self.log.error("Exception in message handler:", exc_info=True)
209 209 finally:
210 210 signal(SIGINT, sig)
211 211
212 212 sys.stdout.flush()
213 213 sys.stderr.flush()
214 214 self._publish_status(u'idle')
215 215
216 216 def enter_eventloop(self):
217 217 """enter eventloop"""
218 218 self.log.info("entering eventloop %s", self.eventloop)
219 219 for stream in self.shell_streams:
220 220 # flush any pending replies,
221 221 # which may be skipped by entering the eventloop
222 222 stream.flush(zmq.POLLOUT)
223 223 # restore default_int_handler
224 224 signal(SIGINT, default_int_handler)
225 225 while self.eventloop is not None:
226 226 try:
227 227 self.eventloop(self)
228 228 except KeyboardInterrupt:
229 229 # Ctrl-C shouldn't crash the kernel
230 230 self.log.error("KeyboardInterrupt caught in kernel")
231 231 continue
232 232 else:
233 233 # eventloop exited cleanly, this means we should stop (right?)
234 234 self.eventloop = None
235 235 break
236 236 self.log.info("exiting eventloop")
237 237
238 238 def start(self):
239 239 """register dispatchers for streams"""
240 240 if self.control_stream:
241 241 self.control_stream.on_recv(self.dispatch_control, copy=False)
242 242
243 243 def make_dispatcher(stream):
244 244 def dispatcher(msg):
245 245 return self.dispatch_shell(stream, msg)
246 246 return dispatcher
247 247
248 248 for s in self.shell_streams:
249 249 s.on_recv(make_dispatcher(s), copy=False)
250 250
251 251 # publish idle status
252 252 self._publish_status('starting')
253 253
254 254 def do_one_iteration(self):
255 255 """step eventloop just once"""
256 256 if self.control_stream:
257 257 self.control_stream.flush()
258 258 for stream in self.shell_streams:
259 259 # handle at most one request per iteration
260 260 stream.flush(zmq.POLLIN, 1)
261 261 stream.flush(zmq.POLLOUT)
262 262
263 263
264 264 def record_ports(self, ports):
265 265 """Record the ports that this kernel is using.
266 266
267 267 The creator of the Kernel instance must call this method if they
268 268 want the :meth:`connect_request` method to return the port numbers.
269 269 """
270 270 self._recorded_ports = ports
271 271
272 272 #---------------------------------------------------------------------------
273 273 # Kernel request handlers
274 274 #---------------------------------------------------------------------------
275 275
276 276 def _make_metadata(self, other=None):
277 277 """init metadata dict, for execute/apply_reply"""
278 278 new_md = {
279 279 'dependencies_met' : True,
280 280 'engine' : self.ident,
281 281 'started': datetime.now(),
282 282 }
283 283 if other:
284 284 new_md.update(other)
285 285 return new_md
286 286
287 287 def _publish_execute_input(self, code, parent, execution_count):
288 288 """Publish the code request on the iopub stream."""
289 289
290 290 self.session.send(self.iopub_socket, u'execute_input',
291 291 {u'code':code, u'execution_count': execution_count},
292 292 parent=parent, ident=self._topic('execute_input')
293 293 )
294 294
295 295 def _publish_status(self, status, parent=None):
296 296 """send status (busy/idle) on IOPub"""
297 297 self.session.send(self.iopub_socket,
298 298 u'status',
299 299 {u'execution_state': status},
300 300 parent=parent or self._parent_header,
301 301 ident=self._topic('status'),
302 302 )
303 303
304 304 def set_parent(self, ident, parent):
305 305 """Set the current parent_header
306 306
307 307 Side effects (IOPub messages) and replies are associated with
308 308 the request that caused them via the parent_header.
309 309
310 310 The parent identity is used to route input_request messages
311 311 on the stdin channel.
312 312 """
313 313 self._parent_ident = ident
314 314 self._parent_header = parent
315 315
316 316 def send_response(self, stream, msg_or_type, content=None, ident=None,
317 317 buffers=None, track=False, header=None, metadata=None):
318 318 """Send a response to the message we're currently processing.
319 319
320 320 This accepts all the parameters of :meth:`IPython.kernel.zmq.session.Session.send`
321 321 except ``parent``.
322 322
323 323 This relies on :meth:`set_parent` having been called for the current
324 324 message.
325 325 """
326 326 return self.session.send(stream, msg_or_type, content, self._parent_header,
327 327 ident, buffers, track, header, metadata)
328 328
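Editor's note: as a usage sketch, a handler running after set_parent() has been called for the current request could publish output on IOPub roughly like this (fragment of a Kernel method; the 'stream' content shape follows the IPython message spec):

    # inside a Kernel method, while a request is being handled:
    self.send_response(self.iopub_socket, u'stream',
        content={u'name': u'stdout', u'text': u'hello from the kernel\n'},
        ident=self._topic('stream'),
    )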
329 329 def execute_request(self, stream, ident, parent):
330 330 """handle an execute_request"""
331 331
332 332 try:
333 333 content = parent[u'content']
334 334 code = py3compat.cast_unicode_py2(content[u'code'])
335 335 silent = content[u'silent']
336 336 store_history = content.get(u'store_history', not silent)
337 337 user_expressions = content.get('user_expressions', {})
338 338 allow_stdin = content.get('allow_stdin', False)
339 339 except:
340 340 self.log.error("Got bad msg: ")
341 341 self.log.error("%s", parent)
342 342 return
343 343
344 344 md = self._make_metadata(parent['metadata'])
345 345
346 346 # Re-broadcast our input for the benefit of listening clients, and
347 347 # start computing output
348 348 if not silent:
349 349 self.execution_count += 1
350 350 self._publish_execute_input(code, parent, self.execution_count)
351 351
352 352 reply_content = self.do_execute(code, silent, store_history,
353 353 user_expressions, allow_stdin)
354 354
355 355 # Flush output before sending the reply.
356 356 sys.stdout.flush()
357 357 sys.stderr.flush()
358 358 # FIXME: on rare occasions, the flush doesn't seem to make it to the
359 359 # clients... This seems to mitigate the problem, but we definitely need
360 360 # to better understand what's going on.
361 361 if self._execute_sleep:
362 362 time.sleep(self._execute_sleep)
363 363
364 364 # Send the reply.
365 365 reply_content = json_clean(reply_content)
366 366
367 367 md['status'] = reply_content['status']
368 368 if reply_content['status'] == 'error' and \
369 369 reply_content['ename'] == 'UnmetDependency':
370 370 md['dependencies_met'] = False
371 371
372 372 reply_msg = self.session.send(stream, u'execute_reply',
373 373 reply_content, parent, metadata=md,
374 374 ident=ident)
375 375
376 376 self.log.debug("%s", reply_msg)
377 377
378 378 if not silent and reply_msg['content']['status'] == u'error':
379 379 self._abort_queues()
380 380
381 381 def do_execute(self, code, silent, store_history=True,
382 382 user_expressions=None, allow_stdin=False):
383 383 """Execute user code. Must be overridden by subclasses.
384 384 """
385 385 raise NotImplementedError
386 386
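Editor's note: do_execute is the one method a wrapper kernel must implement. A minimal echo-style sketch against the base class above (EchoKernel is a hypothetical example, not part of this changeset; the reply dict follows the execute_reply spec):

    class EchoKernel(Kernel):
        implementation = 'echo'
        implementation_version = '1.0'
        language = 'text'
        language_version = '0.1'
        banner = 'Echo kernel - echoes its input'

        def do_execute(self, code, silent, store_history=True,
                       user_expressions=None, allow_stdin=False):
            if not silent:
                # echo the code back to frontends as a stdout stream message
                self.send_response(self.iopub_socket, u'stream',
                    {u'name': u'stdout', u'text': code})
            return {'status': 'ok',
                    'execution_count': self.execution_count,
                    'payload': [],
                    'user_expressions': {}}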
387 387 def complete_request(self, stream, ident, parent):
388 388 content = parent['content']
389 389 code = content['code']
390 390 cursor_pos = content['cursor_pos']
391 391
392 392 matches = self.do_complete(code, cursor_pos)
393 393 matches = json_clean(matches)
394 394 completion_msg = self.session.send(stream, 'complete_reply',
395 395 matches, parent, ident)
396 396 self.log.debug("%s", completion_msg)
397 397
398 398 def do_complete(self, code, cursor_pos):
399 399 """Override in subclasses to find completions.
400 400 """
401 401 return {'matches' : [],
402 402 'cursor_end' : cursor_pos,
403 403 'cursor_start' : cursor_pos,
404 404 'metadata' : {},
405 405 'status' : 'ok'}
406 406
407 407 def inspect_request(self, stream, ident, parent):
408 408 content = parent['content']
409 409
410 410 reply_content = self.do_inspect(content['code'], content['cursor_pos'],
411 411 content.get('detail_level', 0))
412 412 # Before we send this object over, we scrub it for JSON usage
413 413 reply_content = json_clean(reply_content)
414 414 msg = self.session.send(stream, 'inspect_reply',
415 415 reply_content, parent, ident)
416 416 self.log.debug("%s", msg)
417 417
418 418 def do_inspect(self, code, cursor_pos, detail_level=0):
419 419 """Override in subclasses to allow introspection.
420 420 """
421 421 return {'status': 'ok', 'data':{}, 'metadata':{}, 'found':False}
422 422
423 423 def history_request(self, stream, ident, parent):
424 424 content = parent['content']
425 425
426 426 reply_content = self.do_history(**content)
427 427
428 428 reply_content = json_clean(reply_content)
429 429 msg = self.session.send(stream, 'history_reply',
430 430 reply_content, parent, ident)
431 431 self.log.debug("%s", msg)
432 432
433 433 def do_history(self, hist_access_type, output, raw, session=None, start=None,
434 434 stop=None, n=None, pattern=None, unique=False):
435 435 """Override in subclasses to access history.
436 436 """
437 437 return {'history': []}
438 438
439 439 def connect_request(self, stream, ident, parent):
440 440 if self._recorded_ports is not None:
441 441 content = self._recorded_ports.copy()
442 442 else:
443 443 content = {}
444 444 msg = self.session.send(stream, 'connect_reply',
445 445 content, parent, ident)
446 446 self.log.debug("%s", msg)
447 447
448 448 @property
449 449 def kernel_info(self):
450 450 return {
451 451 'protocol_version': release.kernel_protocol_version,
452 452 'implementation': self.implementation,
453 453 'implementation_version': self.implementation_version,
454 454 'language': self.language,
455 455 'language_version': self.language_version,
456 456 'banner': self.banner,
457 457 }
458 458
459 459 def kernel_info_request(self, stream, ident, parent):
460 460 msg = self.session.send(stream, 'kernel_info_reply',
461 461 self.kernel_info, parent, ident)
462 462 self.log.debug("%s", msg)
463 463
464 464 def shutdown_request(self, stream, ident, parent):
465 465 content = self.do_shutdown(parent['content']['restart'])
466 466 self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
467 467 # same content, but different msg_id for broadcasting on IOPub
468 468 self._shutdown_message = self.session.msg(u'shutdown_reply',
469 469 content, parent
470 470 )
471 471
472 472 self._at_shutdown()
473 473 # call sys.exit after a short delay
474 474 loop = ioloop.IOLoop.instance()
475 475 loop.add_timeout(time.time()+0.1, loop.stop)
476 476
477 477 def do_shutdown(self, restart):
478 478 """Override in subclasses to do things when the frontend shuts down the
479 479 kernel.
480 480 """
481 481 return {'status': 'ok', 'restart': restart}
482 482
483 483 def is_complete_request(self, stream, ident, parent):
484 484 content = parent['content']
485 485 code = content['code']
486 486
487 487 reply_content = self.do_is_complete(code)
488 488 reply_content = json_clean(reply_content)
489 489 reply_msg = self.session.send(stream, 'is_complete_reply',
490 490 reply_content, parent, ident)
491 491 self.log.debug("%s", reply_msg)
492 492
493 493 def do_is_complete(self, code):
494 494 """Override in subclasses to find completions.
495 495 """
496 496 return {'status' : 'unknown',
497 497 }
498 498
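Editor's note: a subclass that can tell complete from incomplete input lets frontends decide between executing and prompting for a continuation line. A sketch (per the is_complete_reply spec, status is one of 'complete', 'incomplete', 'invalid', or 'unknown', and incomplete replies may suggest an indent; the backslash rule here is a made-up example):

    def do_is_complete(self, code):
        # hypothetical rule: a trailing backslash means "more input expected"
        if code.rstrip().endswith('\\'):
            return {'status': 'incomplete', 'indent': ''}
        return {'status': 'complete'}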
499 499 #---------------------------------------------------------------------------
500 500 # Engine methods
501 501 #---------------------------------------------------------------------------
502 502
503 503 def apply_request(self, stream, ident, parent):
504 504 try:
505 505 content = parent[u'content']
506 506 bufs = parent[u'buffers']
507 507 msg_id = parent['header']['msg_id']
508 508 except:
509 509 self.log.error("Got bad msg: %s", parent, exc_info=True)
510 510 return
511 511
512 512 md = self._make_metadata(parent['metadata'])
513 513
514 514 reply_content, result_buf = self.do_apply(content, bufs, msg_id, md)
515 515
516 516 # put 'ok'/'error' status in header, for scheduler introspection:
517 517 md['status'] = reply_content['status']
518 518
519 519 # flush i/o
520 520 sys.stdout.flush()
521 521 sys.stderr.flush()
522 522
523 523 self.session.send(stream, u'apply_reply', reply_content,
524 524 parent=parent, ident=ident,buffers=result_buf, metadata=md)
525 525
526 526 def do_apply(self, content, bufs, msg_id, reply_metadata):
527 527 """Override in subclasses to support the IPython parallel framework.
528 528 """
529 529 raise NotImplementedError
530 530
531 531 #---------------------------------------------------------------------------
532 532 # Control messages
533 533 #---------------------------------------------------------------------------
534 534
535 535 def abort_request(self, stream, ident, parent):
536 536 """abort a specific msg by id"""
537 537 msg_ids = parent['content'].get('msg_ids', None)
538 538 if isinstance(msg_ids, string_types):
539 539 msg_ids = [msg_ids]
540 540 if not msg_ids:
541 541 self._abort_queues()
542 542 for mid in msg_ids:
543 543 self.aborted.add(str(mid))
544 544
545 545 content = dict(status='ok')
546 546 reply_msg = self.session.send(stream, 'abort_reply', content=content,
547 547 parent=parent, ident=ident)
548 548 self.log.debug("%s", reply_msg)
549 549
550 550 def clear_request(self, stream, idents, parent):
551 551 """Clear our namespace."""
552 552 content = self.do_clear()
553 553 self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
554 554 content = content)
555 555
556 556 def do_clear(self):
557 557 """Override in subclasses to clear the namespace
558 558
559 559 This is only required for IPython.parallel.
560 560 """
561 561 raise NotImplementedError
562 562
563 563 #---------------------------------------------------------------------------
564 564 # Protected interface
565 565 #---------------------------------------------------------------------------
566 566
567 567 def _topic(self, topic):
568 568 """prefixed topic for IOPub messages"""
569 569 if self.int_id >= 0:
570 570 base = "engine.%i" % self.int_id
571 571 else:
572 572 base = "kernel.%s" % self.ident
573 573
574 574 return py3compat.cast_bytes("%s.%s" % (base, topic))
575 575
576 576 def _abort_queues(self):
577 577 for stream in self.shell_streams:
578 578 if stream:
579 579 self._abort_queue(stream)
580 580
581 581 def _abort_queue(self, stream):
582 582 poller = zmq.Poller()
583 583 poller.register(stream.socket, zmq.POLLIN)
584 584 while True:
585 585 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
586 586 if msg is None:
587 587 return
588 588
589 589 self.log.info("Aborting:")
590 590 self.log.info("%s", msg)
591 591 msg_type = msg['header']['msg_type']
592 592 reply_type = msg_type.split('_')[0] + '_reply'
593 593
594 594 status = {'status' : 'aborted'}
595 595 md = {'engine' : self.ident}
596 596 md.update(status)
597 597 reply_msg = self.session.send(stream, reply_type, metadata=md,
598 598 content=status, parent=msg, ident=idents)
599 599 self.log.debug("%s", reply_msg)
600 600 # We need to wait a bit for requests to come in. This can probably
601 601 # be set shorter for true asynchronous clients.
602 602 poller.poll(50)
603 603
604 604
605 605 def _no_raw_input(self):
606 606 """Raise StdinNotImplentedError if active frontend doesn't support
607 607 stdin."""
608 608 raise StdinNotImplementedError("raw_input was called, but this "
609 609 "frontend does not support stdin.")
610 610
611 611 def getpass(self, prompt=''):
612 612 """Forward getpass to frontends
613 613
614 614 Raises
615 615 ------
616 616 StdinNotImplementedError if active frontend doesn't support stdin.
617 617 """
618 618 if not self._allow_stdin:
619 619 raise StdinNotImplementedError(
620 620 "getpass was called, but this frontend does not support input requests."
621 621 )
622 622 return self._input_request(prompt,
623 623 self._parent_ident,
624 624 self._parent_header,
625 625 password=True,
626 626 )
627 627
628 628 def raw_input(self, prompt=''):
629 629 """Forward raw_input to frontends
630 630
631 631 Raises
632 632 ------
633 633 StdinNotImplementedError if active frontend doesn't support stdin.
634 634 """
635 635 if not self._allow_stdin:
636 636 raise StdinNotImplementedError(
637 637 "raw_input was called, but this frontend does not support input requests."
638 638 )
639 639 return self._input_request(prompt,
640 640 self._parent_ident,
641 641 self._parent_header,
642 642 password=False,
643 643 )
644 644
645 645 def _input_request(self, prompt, ident, parent, password=False):
646 646 # Flush output before making the request.
647 647 sys.stderr.flush()
648 648 sys.stdout.flush()
649 649 # flush the stdin socket, to purge stale replies
650 650 while True:
651 651 try:
652 652 self.stdin_socket.recv_multipart(zmq.NOBLOCK)
653 653 except zmq.ZMQError as e:
654 654 if e.errno == zmq.EAGAIN:
655 655 break
656 656 else:
657 657 raise
658 658
659 659 # Send the input request.
660 660 content = json_clean(dict(prompt=prompt, password=password))
661 661 self.session.send(self.stdin_socket, u'input_request', content, parent,
662 662 ident=ident)
663 663
664 664 # Await a response.
665 665 while True:
666 666 try:
667 667 ident, reply = self.session.recv(self.stdin_socket, 0)
668 668 except Exception:
669 669 self.log.warn("Invalid Message:", exc_info=True)
670 670 except KeyboardInterrupt:
671 671 # re-raise KeyboardInterrupt, to truncate traceback
672 672 raise KeyboardInterrupt
673 673 else:
674 674 break
675 675 try:
676 676 value = py3compat.unicode_to_str(reply['content']['value'])
677 677 except:
678 678 self.log.error("Bad input_reply: %s", parent)
679 679 value = ''
680 680 if value == '\x04':
681 681 # EOF
682 682 raise EOFError
683 683 return value
684 684
685 685 def _at_shutdown(self):
686 686 """Actions taken at shutdown by the kernel, called by python's atexit.
687 687 """
688 688 # io.rprint("Kernel at_shutdown") # dbg
689 689 if self._shutdown_message is not None:
690 690 self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
691 691 self.log.debug("%s", self._shutdown_message)
692 692 [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
@@ -1,185 +1,185
1 1 """serialization utilities for apply messages"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 try:
7 7 import cPickle
8 8 pickle = cPickle
9 9 except:
10 10 cPickle = None
11 11 import pickle
12 12
13 13 # IPython imports
14 14 from IPython.utils import py3compat
15 15 from IPython.utils.data import flatten
16 16 from IPython.utils.pickleutil import (
17 17 can, uncan, can_sequence, uncan_sequence, CannedObject,
18 18 istype, sequence_types, PICKLE_PROTOCOL,
19 19 )
20 20
21 21 if py3compat.PY3:
22 22 buffer = memoryview
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Serialization Functions
26 26 #-----------------------------------------------------------------------------
27 27
28 28 # default values for the thresholds:
29 29 MAX_ITEMS = 64
30 30 MAX_BYTES = 1024
31 31
32 32 def _extract_buffers(obj, threshold=MAX_BYTES):
33 33 """extract buffers larger than a certain threshold"""
34 34 buffers = []
35 35 if isinstance(obj, CannedObject) and obj.buffers:
36 36 for i,buf in enumerate(obj.buffers):
37 37 if len(buf) > threshold:
38 38 # buffer larger than threshold, prevent pickling
39 39 obj.buffers[i] = None
40 40 buffers.append(buf)
41 41 elif isinstance(buf, buffer):
42 42 # buffer too small for separate send, coerce to bytes
43 43 # because pickling buffer objects just results in broken pointers
44 44 obj.buffers[i] = bytes(buf)
45 45 return buffers
46 46
47 47 def _restore_buffers(obj, buffers):
48 48 """restore buffers extracted by """
49 49 if isinstance(obj, CannedObject) and obj.buffers:
50 50 for i,buf in enumerate(obj.buffers):
51 51 if buf is None:
52 52 obj.buffers[i] = buffers.pop(0)
53 53
54 54 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
55 55 """Serialize an object into a list of sendable buffers.
56 56
57 57 Parameters
58 58 ----------
59 59
60 60 obj : object
61 61 The object to be serialized
62 62 buffer_threshold : int
63 63 The threshold (in bytes) for pulling out data buffers
64 64 to avoid pickling them.
65 65 item_threshold : int
66 66 The maximum number of items over which canning will iterate.
67 67 Containers (lists, dicts) larger than this will be pickled without
68 68 introspection.
69 69
70 70 Returns
71 71 -------
72 72 [bufs] : list of buffers representing the serialized object.
73 73 """
74 74 buffers = []
75 75 if istype(obj, sequence_types) and len(obj) < item_threshold:
76 76 cobj = can_sequence(obj)
77 77 for c in cobj:
78 78 buffers.extend(_extract_buffers(c, buffer_threshold))
79 79 elif istype(obj, dict) and len(obj) < item_threshold:
80 80 cobj = {}
81 81 for k in sorted(obj):
82 82 c = can(obj[k])
83 83 buffers.extend(_extract_buffers(c, buffer_threshold))
84 84 cobj[k] = c
85 85 else:
86 86 cobj = can(obj)
87 87 buffers.extend(_extract_buffers(cobj, buffer_threshold))
88 88
89 89 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
90 90 return buffers
91 91
92 def unserialize_object(buffers, g=None):
92 def deserialize_object(buffers, g=None):
93 93 """reconstruct an object serialized by serialize_object from data buffers.
94 94
95 95 Parameters
96 96 ----------
97 97
98 98 buffers : list of buffers/bytes
99 99
100 100 g : globals to be used when uncanning
101 101
102 102 Returns
103 103 -------
104 104
105 105 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
106 106 """
107 107 bufs = list(buffers)
108 108 pobj = bufs.pop(0)
109 109 if not isinstance(pobj, bytes):
110 110 # a zmq message
111 111 pobj = bytes(pobj)
112 112 canned = pickle.loads(pobj)
113 113 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
114 114 for c in canned:
115 115 _restore_buffers(c, bufs)
116 116 newobj = uncan_sequence(canned, g)
117 117 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
118 118 newobj = {}
119 119 for k in sorted(canned):
120 120 c = canned[k]
121 121 _restore_buffers(c, bufs)
122 122 newobj[k] = uncan(c, g)
123 123 else:
124 124 _restore_buffers(canned, bufs)
125 125 newobj = uncan(canned, g)
126 126
127 127 return newobj, bufs
128 128
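Editor's note: together, serialize_object and deserialize_object round-trip an object while pulling large buffers out of the pickle stream. A sketch assuming numpy is available (the 8 MB array exceeds MAX_BYTES, so its buffer travels as a separate frame instead of being pickled):

    import numpy as np
    from IPython.kernel.zmq.serialize import serialize_object, deserialize_object

    a = np.zeros(1024 * 1024)              # large buffer: extracted, not pickled
    bufs = serialize_object({'data': a})   # [pickled canned dict, raw buffer]
    obj, leftover = deserialize_object(bufs)
    assert (obj['data'] == a).all() and leftover == []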
129 129 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
130 130 """pack up a function, args, and kwargs to be sent over the wire
131 131
132 132 Each element of args/kwargs will be canned for special treatment,
133 133 but inspection will not go any deeper than that.
134 134
135 135 Any object whose data is larger than `threshold` will not have its data copied
136 136 (only numpy arrays and bytes/buffers support zero-copy)
137 137
138 138 Message will be a list of bytes/buffers of the format:
139 139
140 140 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
141 141
142 142 With length at least two + len(args) + len(kwargs)
143 143 """
144 144
145 145 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
146 146
147 147 kw_keys = sorted(kwargs.keys())
148 148 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
149 149
150 150 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
151 151
152 152 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
153 153 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
154 154 msg.extend(arg_bufs)
155 155 msg.extend(kwarg_bufs)
156 156
157 157 return msg
158 158
159 159 def unpack_apply_message(bufs, g=None, copy=True):
160 160 """unpack f,args,kwargs from buffers packed by pack_apply_message()
161 161 Returns: original f,args,kwargs"""
162 162 bufs = list(bufs) # allow us to pop
163 163 assert len(bufs) >= 2, "not enough buffers!"
164 164 if not copy:
165 165 for i in range(2):
166 166 bufs[i] = bufs[i].bytes
167 167 f = uncan(pickle.loads(bufs.pop(0)), g)
168 168 info = pickle.loads(bufs.pop(0))
169 169 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
170 170
171 171 args = []
172 172 for i in range(info['nargs']):
173 arg, arg_bufs = unserialize_object(arg_bufs, g)
173 arg, arg_bufs = deserialize_object(arg_bufs, g)
174 174 args.append(arg)
175 175 args = tuple(args)
176 176 assert not arg_bufs, "Shouldn't be any arg bufs left over"
177 177
178 178 kwargs = {}
179 179 for key in info['kw_keys']:
180 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
180 kwarg, kwarg_bufs = deserialize_object(kwarg_bufs, g)
181 181 kwargs[key] = kwarg
182 182 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
183 183
184 184 return f,args,kwargs
185 185
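Editor's note: pack_apply_message and unpack_apply_message round-trip a callable with its arguments, which is how apply_request payloads travel. A sketch with a simple module-level function:

    from IPython.kernel.zmq.serialize import pack_apply_message, unpack_apply_message

    def add(x, y=1):
        return x + y

    bufs = pack_apply_message(add, (2,), {'y': 40})
    f, args, kwargs = unpack_apply_message(bufs)
    assert f(*args, **kwargs) == 42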
@@ -1,865 +1,873
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 import warnings
21 22 from datetime import datetime
22 23
23 24 try:
24 25 import cPickle
25 26 pickle = cPickle
26 27 except:
27 28 cPickle = None
28 29 import pickle
29 30
30 31 try:
31 32 # We are using compare_digest to limit the surface of timing attacks
32 33 from hmac import compare_digest
33 34 except ImportError:
34 35 # Python < 2.7.7: when digests don't match, no feedback is provided,
35 36 # limiting the surface of attack
36 37 def compare_digest(a,b): return a == b
37 38
38 39 import zmq
39 40 from zmq.utils import jsonapi
40 41 from zmq.eventloop.ioloop import IOLoop
41 42 from zmq.eventloop.zmqstream import ZMQStream
42 43
43 44 from IPython.core.release import kernel_protocol_version
44 45 from IPython.config.configurable import Configurable, LoggingConfigurable
45 46 from IPython.utils import io
46 47 from IPython.utils.importstring import import_item
47 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
49 50 iteritems)
50 51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 52 DottedObjectName, CUnicode, Dict, Integer,
52 53 TraitError,
53 54 )
54 55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
55 56 from IPython.kernel.adapter import adapt
56 57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
57 58
58 59 #-----------------------------------------------------------------------------
59 60 # utility functions
60 61 #-----------------------------------------------------------------------------
61 62
62 63 def squash_unicode(obj):
63 64 """coerce unicode back to bytestrings."""
64 65 if isinstance(obj,dict):
65 66 for key in obj.keys():
66 67 obj[key] = squash_unicode(obj[key])
67 68 if isinstance(key, unicode_type):
68 69 obj[squash_unicode(key)] = obj.pop(key)
69 70 elif isinstance(obj, list):
70 71 for i,v in enumerate(obj):
71 72 obj[i] = squash_unicode(v)
72 73 elif isinstance(obj, unicode_type):
73 74 obj = obj.encode('utf8')
74 75 return obj
75 76
76 77 #-----------------------------------------------------------------------------
77 78 # globals and defaults
78 79 #-----------------------------------------------------------------------------
79 80
80 81 # ISO8601-ify datetime objects
81 82 # allow unicode
82 83 # disallow nan, because it's not actually valid JSON
83 84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
84 85 ensure_ascii=False, allow_nan=False,
85 86 )
86 87 json_unpacker = lambda s: jsonapi.loads(s)
87 88
88 89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
89 90 pickle_unpacker = pickle.loads
90 91
91 92 default_packer = json_packer
92 93 default_unpacker = json_unpacker
93 94
94 95 DELIM = b"<IDS|MSG>"
95 96 # singleton dummy tracker, which will always report as done
96 97 DONE = zmq.MessageTracker()
97 98
98 99 #-----------------------------------------------------------------------------
99 100 # Mixin tools for apps that use Sessions
100 101 #-----------------------------------------------------------------------------
101 102
102 103 session_aliases = dict(
103 104 ident = 'Session.session',
104 105 user = 'Session.username',
105 106 keyfile = 'Session.keyfile',
106 107 )
107 108
108 109 session_flags = {
109 110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
110 111 'keyfile' : '' }},
111 112 """Use HMAC digests for authentication of messages.
112 113 Setting this flag will generate a new UUID to use as the HMAC key.
113 114 """),
114 115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
115 116 """Don't authenticate messages."""),
116 117 }
117 118
118 119 def default_secure(cfg):
119 120 """Set the default behavior for a config environment to be secure.
120 121
121 122 If Session.key/keyfile have not been set, set Session.key to
122 123 a new random UUID.
123 124 """
124 125
125 126 if 'Session' in cfg:
126 127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
127 128 return
128 129 # key/keyfile not specified, generate new UUID:
129 130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
130 131
131 132
132 133 #-----------------------------------------------------------------------------
133 134 # Classes
134 135 #-----------------------------------------------------------------------------
135 136
136 137 class SessionFactory(LoggingConfigurable):
137 138 """The Base class for configurables that have a Session, Context, logger,
138 139 and IOLoop.
139 140 """
140 141
141 142 logname = Unicode('')
142 143 def _logname_changed(self, name, old, new):
143 144 self.log = logging.getLogger(new)
144 145
145 146 # not configurable:
146 147 context = Instance('zmq.Context')
147 148 def _context_default(self):
148 149 return zmq.Context.instance()
149 150
150 151 session = Instance('IPython.kernel.zmq.session.Session')
151 152
152 153 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
153 154 def _loop_default(self):
154 155 return IOLoop.instance()
155 156
156 157 def __init__(self, **kwargs):
157 158 super(SessionFactory, self).__init__(**kwargs)
158 159
159 160 if self.session is None:
160 161 # construct the session
161 162 self.session = Session(**kwargs)
162 163
163 164
164 165 class Message(object):
165 166 """A simple message object that maps dict keys to attributes.
166 167
167 168 A Message can be created from a dict and a dict from a Message instance
168 169 simply by calling dict(msg_obj)."""
169 170
170 171 def __init__(self, msg_dict):
171 172 dct = self.__dict__
172 173 for k, v in iteritems(dict(msg_dict)):
173 174 if isinstance(v, dict):
174 175 v = Message(v)
175 176 dct[k] = v
176 177
177 178 # Having this iterator lets dict(msg_obj) work out of the box.
178 179 def __iter__(self):
179 180 return iter(iteritems(self.__dict__))
180 181
181 182 def __repr__(self):
182 183 return repr(self.__dict__)
183 184
184 185 def __str__(self):
185 186 return pprint.pformat(self.__dict__)
186 187
187 188 def __contains__(self, k):
188 189 return k in self.__dict__
189 190
190 191 def __getitem__(self, k):
191 192 return self.__dict__[k]
192 193
193 194
194 195 def msg_header(msg_id, msg_type, username, session):
195 196 date = datetime.now()
196 197 version = kernel_protocol_version
197 198 return locals()
198 199
199 200 def extract_header(msg_or_header):
200 201 """Given a message or header, return the header."""
201 202 if not msg_or_header:
202 203 return {}
203 204 try:
204 205 # See if msg_or_header is the entire message.
205 206 h = msg_or_header['header']
206 207 except KeyError:
207 208 try:
208 209 # See if msg_or_header is just the header
209 210 h = msg_or_header['msg_id']
210 211 except KeyError:
211 212 raise
212 213 else:
213 214 h = msg_or_header
214 215 if not isinstance(h, dict):
215 216 h = dict(h)
216 217 return h
217 218
218 219 class Session(Configurable):
219 220 """Object for handling serialization and sending of messages.
220 221
221 222 The Session object handles building messages and sending them
222 223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
223 224 other over the network via Session objects, and only need to work with the
224 225 dict-based IPython message spec. The Session will handle
225 226 serialization/deserialization, security, and metadata.
226 227
227 228 Sessions support configurable serialization via packer/unpacker traits,
228 229 and signing with HMAC digests via the key/keyfile traits.
229 230
230 231 Parameters
231 232 ----------
232 233
233 234 debug : bool
234 235 whether to trigger extra debugging statements
235 236 packer/unpacker : str : 'json', 'pickle' or import_string
236 237 importstrings for methods to serialize message parts. If just
237 238 'json' or 'pickle', predefined JSON and pickle packers will be used.
238 239 Otherwise, the entire importstring must be used.
239 240
240 241 The functions must accept at least valid JSON input, and output *bytes*.
241 242
242 243 For example, to use msgpack:
243 244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
244 245 pack/unpack : callables
245 246 You can also set the pack/unpack callables for serialization directly.
246 247 session : bytes
247 248 the ID of this Session object. The default is to generate a new UUID.
248 249 username : unicode
249 250 username added to message headers. The default is to ask the OS.
250 251 key : bytes
251 252 The key used to initialize an HMAC signature. If unset, messages
252 253 will not be signed or checked.
253 254 keyfile : filepath
254 255 The file containing a key. If this is set, `key` will be initialized
255 256 to the contents of the file.
256 257
257 258 """
258 259
259 260 debug=Bool(False, config=True, help="""Debug output in the Session""")
260 261
261 262 packer = DottedObjectName('json',config=True,
262 263 help="""The name of the packer for serializing messages.
263 264 Should be one of 'json', 'pickle', or an import name
264 265 for a custom callable serializer.""")
265 266 def _packer_changed(self, name, old, new):
266 267 if new.lower() == 'json':
267 268 self.pack = json_packer
268 269 self.unpack = json_unpacker
269 270 self.unpacker = new
270 271 elif new.lower() == 'pickle':
271 272 self.pack = pickle_packer
272 273 self.unpack = pickle_unpacker
273 274 self.unpacker = new
274 275 else:
275 276 self.pack = import_item(str(new))
276 277
277 278 unpacker = DottedObjectName('json', config=True,
278 279 help="""The name of the unpacker for unserializing messages.
279 280 Only used with custom functions for `packer`.""")
280 281 def _unpacker_changed(self, name, old, new):
281 282 if new.lower() == 'json':
282 283 self.pack = json_packer
283 284 self.unpack = json_unpacker
284 285 self.packer = new
285 286 elif new.lower() == 'pickle':
286 287 self.pack = pickle_packer
287 288 self.unpack = pickle_unpacker
288 289 self.packer = new
289 290 else:
290 291 self.unpack = import_item(str(new))
291 292
292 293 session = CUnicode(u'', config=True,
293 294 help="""The UUID identifying this session.""")
294 295 def _session_default(self):
295 296 u = unicode_type(uuid.uuid4())
296 297 self.bsession = u.encode('ascii')
297 298 return u
298 299
299 300 def _session_changed(self, name, old, new):
300 301 self.bsession = self.session.encode('ascii')
301 302
302 303 # bsession is the session as bytes
303 304 bsession = CBytes(b'')
304 305
305 306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
306 307 help="""Username for the Session. Default is your system username.""",
307 308 config=True)
308 309
309 310 metadata = Dict({}, config=True,
310 311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
311 312
312 313 # if 0, no adapting to do.
313 314 adapt_version = Integer(0)
314 315
315 316 # message signature related traits:
316 317
317 318 key = CBytes(b'', config=True,
318 319 help="""execution key, for extra authentication.""")
319 320 def _key_changed(self):
320 321 self._new_auth()
321 322
322 323 signature_scheme = Unicode('hmac-sha256', config=True,
323 324 help="""The digest scheme used to construct the message signatures.
324 325 Must have the form 'hmac-HASH'.""")
325 326 def _signature_scheme_changed(self, name, old, new):
326 327 if not new.startswith('hmac-'):
327 328 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
328 329 hash_name = new.split('-', 1)[1]
329 330 try:
330 331 self.digest_mod = getattr(hashlib, hash_name)
331 332 except AttributeError:
332 333 raise TraitError("hashlib has no such attribute: %s" % hash_name)
333 334 self._new_auth()
334 335
335 336 digest_mod = Any()
336 337 def _digest_mod_default(self):
337 338 return hashlib.sha256
338 339
339 340 auth = Instance(hmac.HMAC)
340 341
341 342 def _new_auth(self):
342 343 if self.key:
343 344 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
344 345 else:
345 346 self.auth = None
346 347
347 348 digest_history = Set()
348 349 digest_history_size = Integer(2**16, config=True,
349 350 help="""The maximum number of digests to remember.
350 351
351 352 The digest history will be culled when it exceeds this value.
352 353 """
353 354 )
354 355
355 356 keyfile = Unicode('', config=True,
356 357 help="""path to file containing execution key.""")
357 358 def _keyfile_changed(self, name, old, new):
358 359 with open(new, 'rb') as f:
359 360 self.key = f.read().strip()
360 361
361 362 # for protecting against sends from forks
362 363 pid = Integer()
363 364
364 365 # serialization traits:
365 366
366 367 pack = Any(default_packer) # the actual packer function
367 368 def _pack_changed(self, name, old, new):
368 369 if not callable(new):
369 370 raise TypeError("packer must be callable, not %s"%type(new))
370 371
371 372 unpack = Any(default_unpacker) # the actual unpacker function
372 373 def _unpack_changed(self, name, old, new):
373 374 # unpacker is not checked - it is assumed to be
374 375 if not callable(new):
375 376 raise TypeError("unpacker must be callable, not %s"%type(new))
376 377
377 378 # thresholds:
378 379 copy_threshold = Integer(2**16, config=True,
379 380 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
380 381 buffer_threshold = Integer(MAX_BYTES, config=True,
381 382 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
382 383 item_threshold = Integer(MAX_ITEMS, config=True,
383 384 help="""The maximum number of items for a container to be introspected for custom serialization.
384 385 Containers larger than this are pickled outright.
385 386 """
386 387 )
387 388
388 389
389 390 def __init__(self, **kwargs):
390 391 """create a Session object
391 392
392 393 Parameters
393 394 ----------
394 395
395 396 debug : bool
396 397 whether to trigger extra debugging statements
397 398 packer/unpacker : str : 'json', 'pickle' or import_string
398 399 importstrings for methods to serialize message parts. If just
399 400 'json' or 'pickle', predefined JSON and pickle packers will be used.
400 401 Otherwise, the entire importstring must be used.
401 402
402 403 The functions must accept at least valid JSON input, and output
403 404 *bytes*.
404 405
405 406 For example, to use msgpack:
406 407 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
407 408 pack/unpack : callables
408 409 You can also set the pack/unpack callables for serialization
409 410 directly.
410 411 session : unicode (must be ascii)
411 412 the ID of this Session object. The default is to generate a new
412 413 UUID.
413 414 bsession : bytes
414 415 The session as bytes
415 416 username : unicode
416 417 username added to message headers. The default is to ask the OS.
417 418 key : bytes
418 419 The key used to initialize an HMAC signature. If unset, messages
419 420 will not be signed or checked.
420 421 signature_scheme : str
421 422 The message digest scheme. Currently must be of the form 'hmac-HASH',
422 423 where 'HASH' is a hashing function available in Python's hashlib.
423 424 The default is 'hmac-sha256'.
424 425 This is ignored if 'key' is empty.
425 426 keyfile : filepath
426 427 The file containing a key. If this is set, `key` will be
427 428 initialized to the contents of the file.
428 429 """
429 430 super(Session, self).__init__(**kwargs)
430 431 self._check_packers()
431 432 self.none = self.pack({})
432 433 # ensure self._session_default() if necessary, so bsession is defined:
433 434 self.session
434 435 self.pid = os.getpid()
435 436
436 437 @property
437 438 def msg_id(self):
438 439 """always return new uuid"""
439 440 return str(uuid.uuid4())
440 441
441 442 def _check_packers(self):
442 443 """check packers for datetime support."""
443 444 pack = self.pack
444 445 unpack = self.unpack
445 446
446 447 # check simple serialization
447 448 msg = dict(a=[1,'hi'])
448 449 try:
449 450 packed = pack(msg)
450 451 except Exception as e:
451 452 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
452 453 if self.packer == 'json':
453 454 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
454 455 else:
455 456 jsonmsg = ""
456 457 raise ValueError(
457 458 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
458 459 )
459 460
460 461 # ensure packed message is bytes
461 462 if not isinstance(packed, bytes):
462 463 raise ValueError("message packed to %r, but bytes are required"%type(packed))
463 464
464 465 # check that unpack is pack's inverse
465 466 try:
466 467 unpacked = unpack(packed)
467 468 assert unpacked == msg
468 469 except Exception as e:
469 470 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
470 471 if self.packer == 'json':
471 472 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
472 473 else:
473 474 jsonmsg = ""
474 475 raise ValueError(
475 476 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
476 477 )
477 478
478 479 # check datetime support
479 480 msg = dict(t=datetime.now())
480 481 try:
481 482 unpacked = unpack(pack(msg))
482 483 if isinstance(unpacked['t'], datetime):
483 484 raise ValueError("Shouldn't deserialize to datetime")
484 485 except Exception:
485 486 self.pack = lambda o: pack(squash_dates(o))
486 487 self.unpack = lambda s: unpack(s)
487 488
488 489 def msg_header(self, msg_type):
489 490 return msg_header(self.msg_id, msg_type, self.username, self.session)
490 491
491 492 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
492 493 """Return the nested message dict.
493 494
494 495 This format is different from what is sent over the wire. The
495 serialize/unserialize methods converts this nested message dict to the wire
496 serialize/deserialize methods convert this nested message dict to the wire
496 497 format, which is a list of message parts.
497 498 """
498 499 msg = {}
499 500 header = self.msg_header(msg_type) if header is None else header
500 501 msg['header'] = header
501 502 msg['msg_id'] = header['msg_id']
502 503 msg['msg_type'] = header['msg_type']
503 504 msg['parent_header'] = {} if parent is None else extract_header(parent)
504 505 msg['content'] = {} if content is None else content
505 506 msg['metadata'] = self.metadata.copy()
506 507 if metadata is not None:
507 508 msg['metadata'].update(metadata)
508 509 return msg
509 510
510 511 def sign(self, msg_list):
511 512 """Sign a message with HMAC digest. If no auth, return b''.
512 513
513 514 Parameters
514 515 ----------
515 516 msg_list : list
516 517 The [p_header,p_parent,p_metadata,p_content] part of the message list.
517 518 """
518 519 if self.auth is None:
519 520 return b''
520 521 h = self.auth.copy()
521 522 for m in msg_list:
522 523 h.update(m)
523 524 return str_to_bytes(h.hexdigest())
524 525
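Editor's note: the signature is just a hex HMAC over the packed message frames, so it can be reproduced (or verified) outside the class. A sketch for the default hmac-sha256 scheme, with a made-up key and frame contents:

    import hashlib
    import hmac

    key = b'secret-key'
    frames = [b'p_header', b'p_parent', b'p_metadata', b'p_content']
    h = hmac.HMAC(key, digestmod=hashlib.sha256)
    for frame in frames:
        h.update(frame)
    signature = h.hexdigest().encode('ascii')  # same form sign() returns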
525 526 def serialize(self, msg, ident=None):
526 527 """Serialize the message components to bytes.
527 528
528 This is roughly the inverse of unserialize. The serialize/unserialize
529 This is roughly the inverse of deserialize. The serialize/deserialize
529 530 methods work with full message lists, whereas pack/unpack work with
530 531 the individual message parts in the message list.
531 532
532 533 Parameters
533 534 ----------
534 535 msg : dict or Message
535 536 The next message dict as returned by the self.msg method.
536 537
537 538 Returns
538 539 -------
539 540 msg_list : list
540 541 The list of bytes objects to be sent with the format::
541 542
542 543 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
543 544 p_metadata, p_content, buffer1, buffer2, ...]
544 545
545 546 In this list, the ``p_*`` entities are the packed or serialized
546 547 versions, so if JSON is used, these are utf8 encoded JSON strings.
547 548 """
548 549 content = msg.get('content', {})
549 550 if content is None:
550 551 content = self.none
551 552 elif isinstance(content, dict):
552 553 content = self.pack(content)
553 554 elif isinstance(content, bytes):
554 555 # content is already packed, as in a relayed message
555 556 pass
556 557 elif isinstance(content, unicode_type):
557 558 # should be bytes, but JSON often spits out unicode
558 559 content = content.encode('utf8')
559 560 else:
560 561 raise TypeError("Content incorrect type: %s"%type(content))
561 562
562 563 real_message = [self.pack(msg['header']),
563 564 self.pack(msg['parent_header']),
564 565 self.pack(msg['metadata']),
565 566 content,
566 567 ]
567 568
568 569 to_send = []
569 570
570 571 if isinstance(ident, list):
571 572 # accept list of idents
572 573 to_send.extend(ident)
573 574 elif ident is not None:
574 575 to_send.append(ident)
575 576 to_send.append(DELIM)
576 577
577 578 signature = self.sign(real_message)
578 579 to_send.append(signature)
579 580
580 581 to_send.extend(real_message)
581 582
582 583 return to_send
583 584
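Editor's note: concretely, with no routing idents the frame list starts with DELIM. A sketch (Session and its key trait are real; the key and content here are made up):

    from IPython.kernel.zmq.session import Session

    s = Session(key=b'secret-key')
    msg = s.msg('execute_request', content={'code': '1+1'})
    wire = s.serialize(msg)
    # wire == [DELIM, signature, p_header, p_parent, p_metadata, p_content]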
584 585 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
585 586 buffers=None, track=False, header=None, metadata=None):
586 587 """Build and send a message via stream or socket.
587 588
588 589 The message format used by this function internally is as follows:
589 590
590 591 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
591 592 buffer1,buffer2,...]
592 593
593 The serialize/unserialize methods convert the nested message dict into this
594 The serialize/deserialize methods convert the nested message dict into this
594 595 format.
595 596
596 597 Parameters
597 598 ----------
598 599
599 600 stream : zmq.Socket or ZMQStream
600 601 The socket-like object used to send the data.
601 602 msg_or_type : str or Message/dict
602 603 Normally, msg_or_type will be a msg_type unless a message is being
603 604 sent more than once. If a header is supplied, this can be set to
604 605 None and the msg_type will be pulled from the header.
605 606
606 607 content : dict or None
607 608 The content of the message (ignored if msg_or_type is a message).
608 609 header : dict or None
609 610 The header dict for the message (ignored if msg_or_type is a message).
610 611 parent : Message or dict or None
611 612 The parent or parent header describing the parent of this message
612 613 (ignored if msg_or_type is a message).
613 614 ident : bytes or list of bytes
614 615 The zmq.IDENTITY routing path.
615 616 metadata : dict or None
616 617 The metadata describing the message
617 618 buffers : list or None
618 619 The already-serialized buffers to be appended to the message.
619 620 track : bool
620 621 Whether to track. Only for use with Sockets, because ZMQStream
621 622 objects cannot track messages.
622 623
623 624
624 625 Returns
625 626 -------
626 627 msg : dict
627 628 The constructed message.
628 629 """
629 630 if not isinstance(stream, zmq.Socket):
630 631 # ZMQStreams and dummy sockets do not support tracking.
631 632 track = False
632 633
633 634 if isinstance(msg_or_type, (Message, dict)):
634 635 # We got a Message or message dict, not a msg_type so don't
635 636 # build a new Message.
636 637 msg = msg_or_type
638 buffers = buffers or msg.get('buffers', [])
637 639 else:
638 640 msg = self.msg(msg_or_type, content=content, parent=parent,
639 641 header=header, metadata=metadata)
640 642 if not os.getpid() == self.pid:
641 643 io.rprint("WARNING: attempted to send message from fork")
642 644 io.rprint(msg)
643 645 return
644 646 buffers = [] if buffers is None else buffers
645 647 if self.adapt_version:
646 648 msg = adapt(msg, self.adapt_version)
647 649 to_send = self.serialize(msg, ident)
648 650 to_send.extend(buffers)
649 651 longest = max([ len(s) for s in to_send ])
650 652 copy = (longest < self.copy_threshold)
651 653
652 654 if buffers and track and not copy:
653 655 # only really track when we are doing zero-copy buffers
654 656 tracker = stream.send_multipart(to_send, copy=False, track=True)
655 657 else:
656 658 # use dummy tracker, which will be done immediately
657 659 tracker = DONE
658 660 stream.send_multipart(to_send, copy=copy)
659 661
660 662 if self.debug:
661 663 pprint.pprint(msg)
662 664 pprint.pprint(to_send)
663 665 pprint.pprint(buffers)
664 666
665 667 msg['tracker'] = tracker
666 668
667 669 return msg
668 670
669 671 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
670 672 """Send a raw message via ident path.
671 673
672 674 This method is used to send an already serialized message.
673 675
674 676 Parameters
675 677 ----------
676 678 stream : ZMQStream or Socket
677 679 The ZMQ stream or socket to use for sending the message.
678 680 msg_list : list
679 681 The serialized list of messages to send. This only includes the
680 682 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
681 683 the message.
682 684 ident : ident or list
683 685 A single ident or a list of idents to use in sending.
684 686 """
685 687 to_send = []
686 688 if isinstance(ident, bytes):
687 689 ident = [ident]
688 690 if ident is not None:
689 691 to_send.extend(ident)
690 692
691 693 to_send.append(DELIM)
692 694 to_send.append(self.sign(msg_list))
693 695 to_send.extend(msg_list)
694 696 stream.send_multipart(to_send, flags, copy=copy)
695 697
696 698 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
697 699 """Receive and unpack a message.
698 700
699 701 Parameters
700 702 ----------
701 703 socket : ZMQStream or Socket
702 704 The socket or stream to use in receiving.
703 705
704 706 Returns
705 707 -------
706 708 [idents], msg
707 709 [idents] is a list of idents and msg is a nested message dict of
708 710 the same format as self.msg returns.
709 711 """
710 712 if isinstance(socket, ZMQStream):
711 713 socket = socket.socket
712 714 try:
713 715 msg_list = socket.recv_multipart(mode, copy=copy)
714 716 except zmq.ZMQError as e:
715 717 if e.errno == zmq.EAGAIN:
716 718 # We can convert EAGAIN to None as we know in this case
717 719 # recv_multipart won't return None.
718 720 return None,None
719 721 else:
720 722 raise
721 723 # split multipart message into identity list and message dict
722 724 # invalid large messages can cause very expensive string comparisons
723 725 idents, msg_list = self.feed_identities(msg_list, copy)
724 726 try:
725 return idents, self.unserialize(msg_list, content=content, copy=copy)
727 return idents, self.deserialize(msg_list, content=content, copy=copy)
726 728 except Exception as e:
727 729 # TODO: handle it
728 730 raise e
729 731
730 732 def feed_identities(self, msg_list, copy=True):
731 733 """Split the identities from the rest of the message.
732 734
733 735 Feed until DELIM is reached, then return the prefix as idents and
734 736 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
735 737 but that would be silly.
736 738
737 739 Parameters
738 740 ----------
739 741 msg_list : a list of Message or bytes objects
740 742 The message to be split.
741 743 copy : bool
742 744 flag determining whether the arguments are bytes or Messages
743 745
744 746 Returns
745 747 -------
746 748 (idents, msg_list) : two lists
747 749 idents will always be a list of bytes, each of which is a ZMQ
748 750 identity. msg_list will be a list of bytes or zmq.Messages of the
749 751 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
750 should be unpackable/unserializable via self.unserialize at this
752 should be unpackable/deserializable via self.deserialize at this
751 753 point.
752 754 """
753 755 if copy:
754 756 idx = msg_list.index(DELIM)
755 757 return msg_list[:idx], msg_list[idx+1:]
756 758 else:
757 759 failed = True
758 760 for idx,m in enumerate(msg_list):
759 761 if m.bytes == DELIM:
760 762 failed = False
761 763 break
762 764 if failed:
763 765 raise ValueError("DELIM not in msg_list")
764 766 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
765 767 return [m.bytes for m in idents], msg_list
766 768
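Editor's note: on the receive side this undoes the routing prefix. Reusing the s and wire from the serialize sketch above, with one made-up ident prepended:

    idents, msg_list = s.feed_identities([b'client-id'] + wire)
    # idents == [b'client-id']
    # msg_list == [signature, p_header, p_parent, p_metadata, p_content]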
767 769 def _add_digest(self, signature):
768 770 """add a digest to history to protect against replay attacks"""
769 771 if self.digest_history_size == 0:
770 772 # no history, never add digests
771 773 return
772 774
773 775 self.digest_history.add(signature)
774 776 if len(self.digest_history) > self.digest_history_size:
775 777 # threshold reached, cull 10%
776 778 self._cull_digest_history()
777 779
778 780 def _cull_digest_history(self):
779 781 """cull the digest history
780 782
781 783 Removes a randomly selected 10% of the digest history
782 784 """
783 785 current = len(self.digest_history)
784 786 n_to_cull = max(int(current // 10), current - self.digest_history_size)
785 787 if n_to_cull >= current:
786 788 self.digest_history = set()
787 789 return
788 790 to_cull = random.sample(self.digest_history, n_to_cull)
789 791 self.digest_history.difference_update(to_cull)
790 792
791 def unserialize(self, msg_list, content=True, copy=True):
793 def deserialize(self, msg_list, content=True, copy=True):
792 794 """Unserialize a msg_list to a nested message dict.
793 795
794 This is roughly the inverse of serialize. The serialize/unserialize
796 This is roughly the inverse of serialize. The serialize/deserialize
795 797 methods work with full message lists, whereas pack/unpack work with
796 798 the individual message parts in the message list.
797 799
798 800 Parameters
799 801 ----------
800 802 msg_list : list of bytes or Message objects
801 803 The list of message parts of the form [HMAC,p_header,p_parent,
802 804 p_metadata,p_content,buffer1,buffer2,...].
803 805 content : bool (True)
804 806 Whether to unpack the content dict (True), or leave it packed
805 807 (False).
806 808 copy : bool (True)
807 809 Whether to return the bytes (True), or the non-copying Message
808 810 object in each place (False).
809 811
810 812 Returns
811 813 -------
812 814 msg : dict
813 815 The nested message dict with top-level keys [header, parent_header,
814 816 content, buffers].
815 817 """
816 818 minlen = 5
817 819 message = {}
818 820 if not copy:
819 821 for i in range(minlen):
820 822 msg_list[i] = msg_list[i].bytes
821 823 if self.auth is not None:
822 824 signature = msg_list[0]
823 825 if not signature:
824 826 raise ValueError("Unsigned Message")
825 827 if signature in self.digest_history:
826 828 raise ValueError("Duplicate Signature: %r" % signature)
827 829 self._add_digest(signature)
828 830 check = self.sign(msg_list[1:5])
829 831 if not compare_digest(signature, check):
830 832 raise ValueError("Invalid Signature: %r" % signature)
831 833 if not len(msg_list) >= minlen:
832 834 raise TypeError("malformed message, must have at least %i elements"%minlen)
833 835 header = self.unpack(msg_list[1])
834 836 message['header'] = extract_dates(header)
835 837 message['msg_id'] = header['msg_id']
836 838 message['msg_type'] = header['msg_type']
837 839 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
838 840 message['metadata'] = self.unpack(msg_list[3])
839 841 if content:
840 842 message['content'] = self.unpack(msg_list[4])
841 843 else:
842 844 message['content'] = msg_list[4]
843 845
844 846 message['buffers'] = msg_list[5:]
845 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
846 847 # adapt to the current version
847 848 return adapt(message)
848 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
849
850 def unserialize(self, *args, **kwargs):
851 warnings.warn(
852 "Session.unserialize is deprecated. Use Session.deserialize.",
853 DeprecationWarning,
854 )
855 return self.deserialize(*args, **kwargs)
856
849 857
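# A minimal round-trip sketch mirroring the tests: serialize a message,
# strip the identities, then deserialize it back (illustrative helper only):
def _roundtrip_example():
    session = Session()
    wire = session.serialize(session.msg('execute', content={'code': 'pass'}))
    idents, msg_list = session.feed_identities(wire)
    msg = session.deserialize(msg_list)
    assert msg['msg_type'] == 'execute'
    return idents, msg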
850 858 def test_msg2obj():
851 859 am = dict(x=1)
852 860 ao = Message(am)
853 861 assert ao.x == am['x']
854 862
855 863 am['y'] = dict(z=1)
856 864 ao = Message(am)
857 865 assert ao.y.z == am['y']['z']
858 866
859 867 k1, k2 = 'y', 'z'
860 868 assert ao[k1][k2] == am[k1][k2]
861 869
862 870 am2 = dict(ao)
863 871 assert am['x'] == am2['x']
864 872 assert am['y']['z'] == am2['y']['z']
865 873
@@ -1,208 +1,208
1 1 """test serialization tools"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import pickle
7 7 from collections import namedtuple
8 8
9 9 import nose.tools as nt
10 10
11 11 # from unittest import TestCase
12 from IPython.kernel.zmq.serialize import serialize_object, unserialize_object
12 from IPython.kernel.zmq.serialize import serialize_object, deserialize_object
13 13 from IPython.testing import decorators as dec
14 14 from IPython.utils.pickleutil import CannedArray, CannedClass
15 15 from IPython.utils.py3compat import iteritems
16 16 from IPython.parallel import interactive
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Globals and Utilities
20 20 #-------------------------------------------------------------------------------
21 21
22 22 def roundtrip(obj):
23 23 """roundtrip an object through serialization"""
24 24 bufs = serialize_object(obj)
25 obj2, remainder = unserialize_object(bufs)
25 obj2, remainder = deserialize_object(bufs)
26 26 nt.assert_equals(remainder, [])
27 27 return obj2
28 28
29 29 class C(object):
30 30 """dummy class for """
31 31
32 32 def __init__(self, **kwargs):
33 33 for key,value in iteritems(kwargs):
34 34 setattr(self, key, value)
35 35
36 36 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
37 37 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
38 38
39 39 #-------------------------------------------------------------------------------
40 40 # Tests
41 41 #-------------------------------------------------------------------------------
42 42
43 43 def new_array(shape, dtype):
44 44 import numpy
45 45 return numpy.random.random(shape).astype(dtype)
46 46
47 47 def test_roundtrip_simple():
48 48 for obj in [
49 49 'hello',
50 50 dict(a='b', b=10),
51 51 [1,2,'hi'],
52 52 (b'123', 'hello'),
53 53 ]:
54 54 obj2 = roundtrip(obj)
55 55 nt.assert_equal(obj, obj2)
56 56
57 57 def test_roundtrip_nested():
58 58 for obj in [
59 59 dict(a=range(5), b={1:b'hello'}),
60 60 [range(5),[range(3),(1,[b'whoda'])]],
61 61 ]:
62 62 obj2 = roundtrip(obj)
63 63 nt.assert_equal(obj, obj2)
64 64
65 65 def test_roundtrip_buffered():
66 66 for obj in [
67 67 dict(a=b"x"*1025),
68 68 b"hello"*500,
69 69 [b"hello"*501, 1,2,3]
70 70 ]:
71 71 bufs = serialize_object(obj)
72 72 nt.assert_equal(len(bufs), 2)
73 obj2, remainder = unserialize_object(bufs)
73 obj2, remainder = deserialize_object(bufs)
74 74 nt.assert_equal(remainder, [])
75 75 nt.assert_equal(obj, obj2)
76 76
77 77 @dec.skip_without('numpy')
78 78 def test_numpy():
79 79 import numpy
80 80 from numpy.testing.utils import assert_array_equal
81 81 for shape in SHAPES:
82 82 for dtype in DTYPES:
83 83 A = new_array(shape, dtype=dtype)
84 84 bufs = serialize_object(A)
85 B, r = unserialize_object(bufs)
85 B, r = deserialize_object(bufs)
86 86 nt.assert_equal(r, [])
87 87 nt.assert_equal(A.shape, B.shape)
88 88 nt.assert_equal(A.dtype, B.dtype)
89 89 assert_array_equal(A,B)
90 90
91 91 @dec.skip_without('numpy')
92 92 def test_recarray():
93 93 import numpy
94 94 from numpy.testing.utils import assert_array_equal
95 95 for shape in SHAPES:
96 96 for dtype in [
97 97 [('f', float), ('s', '|S10')],
98 98 [('n', int), ('s', '|S1'), ('u', 'uint32')],
99 99 ]:
100 100 A = new_array(shape, dtype=dtype)
101 101
102 102 bufs = serialize_object(A)
103 B, r = unserialize_object(bufs)
103 B, r = deserialize_object(bufs)
104 104 nt.assert_equal(r, [])
105 105 nt.assert_equal(A.shape, B.shape)
106 106 nt.assert_equal(A.dtype, B.dtype)
107 107 assert_array_equal(A,B)
108 108
109 109 @dec.skip_without('numpy')
110 110 def test_numpy_in_seq():
111 111 import numpy
112 112 from numpy.testing.utils import assert_array_equal
113 113 for shape in SHAPES:
114 114 for dtype in DTYPES:
115 115 A = new_array(shape, dtype=dtype)
116 116 bufs = serialize_object((A,1,2,b'hello'))
117 117 canned = pickle.loads(bufs[0])
118 118 nt.assert_is_instance(canned[0], CannedArray)
119 tup, r = unserialize_object(bufs)
119 tup, r = deserialize_object(bufs)
120 120 B = tup[0]
121 121 nt.assert_equal(r, [])
122 122 nt.assert_equal(A.shape, B.shape)
123 123 nt.assert_equal(A.dtype, B.dtype)
124 124 assert_array_equal(A,B)
125 125
126 126 @dec.skip_without('numpy')
127 127 def test_numpy_in_dict():
128 128 import numpy
129 129 from numpy.testing.utils import assert_array_equal
130 130 for shape in SHAPES:
131 131 for dtype in DTYPES:
132 132 A = new_array(shape, dtype=dtype)
133 133 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
134 134 canned = pickle.loads(bufs[0])
135 135 nt.assert_is_instance(canned['a'], CannedArray)
136 d, r = unserialize_object(bufs)
136 d, r = deserialize_object(bufs)
137 137 B = d['a']
138 138 nt.assert_equal(r, [])
139 139 nt.assert_equal(A.shape, B.shape)
140 140 nt.assert_equal(A.dtype, B.dtype)
141 141 assert_array_equal(A,B)
142 142
143 143 def test_class():
144 144 @interactive
145 145 class C(object):
146 146 a=5
147 147 bufs = serialize_object(dict(C=C))
148 148 canned = pickle.loads(bufs[0])
149 149 nt.assert_is_instance(canned['C'], CannedClass)
150 d, r = unserialize_object(bufs)
150 d, r = deserialize_object(bufs)
151 151 C2 = d['C']
152 152 nt.assert_equal(C2.a, C.a)
153 153
154 154 def test_class_oldstyle():
155 155 @interactive
156 156 class C:
157 157 a=5
158 158
159 159 bufs = serialize_object(dict(C=C))
160 160 canned = pickle.loads(bufs[0])
161 161 nt.assert_is_instance(canned['C'], CannedClass)
162 d, r = unserialize_object(bufs)
162 d, r = deserialize_object(bufs)
163 163 C2 = d['C']
164 164 nt.assert_equal(C2.a, C.a)
165 165
166 166 def test_tuple():
167 167 tup = (lambda x:x, 1)
168 168 bufs = serialize_object(tup)
169 169 canned = pickle.loads(bufs[0])
170 170 nt.assert_is_instance(canned, tuple)
171 t2, r = unserialize_object(bufs)
171 t2, r = deserialize_object(bufs)
172 172 nt.assert_equal(t2[0](t2[1]), tup[0](tup[1]))
173 173
174 174 point = namedtuple('point', 'x y')
175 175
176 176 def test_namedtuple():
177 177 p = point(1,2)
178 178 bufs = serialize_object(p)
179 179 canned = pickle.loads(bufs[0])
180 180 nt.assert_is_instance(canned, point)
181 p2, r = unserialize_object(bufs, globals())
181 p2, r = deserialize_object(bufs, globals())
182 182 nt.assert_equal(p2.x, p.x)
183 183 nt.assert_equal(p2.y, p.y)
184 184
185 185 def test_list():
186 186 lis = [lambda x:x, 1]
187 187 bufs = serialize_object(lis)
188 188 canned = pickle.loads(bufs[0])
189 189 nt.assert_is_instance(canned, list)
190 l2, r = unserialize_object(bufs)
190 l2, r = deserialize_object(bufs)
191 191 nt.assert_equal(l2[0](l2[1]), lis[0](lis[1]))
192 192
193 193 def test_class_inheritance():
194 194 @interactive
195 195 class C(object):
196 196 a=5
197 197
198 198 @interactive
199 199 class D(C):
200 200 b=10
201 201
202 202 bufs = serialize_object(dict(D=D))
203 203 canned = pickle.loads(bufs[0])
204 204 nt.assert_is_instance(canned['D'], CannedClass)
205 d, r = unserialize_object(bufs)
205 d, r = deserialize_object(bufs)
206 206 D2 = d['D']
207 207 nt.assert_equal(D2.a, D.a)
208 208 nt.assert_equal(D2.b, D.b)
@@ -1,318 +1,318
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 from datetime import datetime
17 17
18 18 import zmq
19 19
20 20 from zmq.tests import BaseZMQTestCase
21 21 from zmq.eventloop.zmqstream import ZMQStream
22 22
23 23 from IPython.kernel.zmq import session as ss
24 24
25 25 from IPython.testing.decorators import skipif, module_not_available
26 26 from IPython.utils.py3compat import string_types
27 27 from IPython.utils import jsonutil
28 28
29 29 def _bad_packer(obj):
30 30 raise TypeError("I don't work")
31 31
32 32 def _bad_unpacker(bytes):
33 33 raise TypeError("I don't work either")
34 34
35 35 class SessionTestCase(BaseZMQTestCase):
36 36
37 37 def setUp(self):
38 38 BaseZMQTestCase.setUp(self)
39 39 self.session = ss.Session()
40 40
41 41
42 42 class TestSession(SessionTestCase):
43 43
44 44 def test_msg(self):
45 45 """message format"""
46 46 msg = self.session.msg('execute')
47 47 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
48 48 s = set(msg.keys())
49 49 self.assertEqual(s, thekeys)
50 50 self.assertTrue(isinstance(msg['content'],dict))
51 51 self.assertTrue(isinstance(msg['metadata'],dict))
52 52 self.assertTrue(isinstance(msg['header'],dict))
53 53 self.assertTrue(isinstance(msg['parent_header'],dict))
54 54 self.assertTrue(isinstance(msg['msg_id'],str))
55 55 self.assertTrue(isinstance(msg['msg_type'],str))
56 56 self.assertEqual(msg['header']['msg_type'], 'execute')
57 57 self.assertEqual(msg['msg_type'], 'execute')
58 58
59 59 def test_serialize(self):
60 60 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
61 61 msg_list = self.session.serialize(msg, ident=b'foo')
62 62 ident, msg_list = self.session.feed_identities(msg_list)
63 new_msg = self.session.unserialize(msg_list)
63 new_msg = self.session.deserialize(msg_list)
64 64 self.assertEqual(ident[0], b'foo')
65 65 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
66 66 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
67 67 self.assertEqual(new_msg['header'],msg['header'])
68 68 self.assertEqual(new_msg['content'],msg['content'])
69 69 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
70 70 self.assertEqual(new_msg['metadata'],msg['metadata'])
71 71 # ensure floats don't come out as Decimal:
72 72 self.assertEqual(type(new_msg['content']['b']), type(msg['content']['b']))
73 73
74 74 def test_send(self):
75 75 ctx = zmq.Context.instance()
76 76 A = ctx.socket(zmq.PAIR)
77 77 B = ctx.socket(zmq.PAIR)
78 78 A.bind("inproc://test")
79 79 B.connect("inproc://test")
80 80
81 81 msg = self.session.msg('execute', content=dict(a=10))
82 82 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
83 83
84 84 ident, msg_list = self.session.feed_identities(B.recv_multipart())
85 new_msg = self.session.unserialize(msg_list)
85 new_msg = self.session.deserialize(msg_list)
86 86 self.assertEqual(ident[0], b'foo')
87 87 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
88 88 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
89 89 self.assertEqual(new_msg['header'],msg['header'])
90 90 self.assertEqual(new_msg['content'],msg['content'])
91 91 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
92 92 self.assertEqual(new_msg['metadata'],msg['metadata'])
93 93 self.assertEqual(new_msg['buffers'],[b'bar'])
94 94
95 95 content = msg['content']
96 96 header = msg['header']
97 97 parent = msg['parent_header']
98 98 metadata = msg['metadata']
99 99 msg_type = header['msg_type']
100 100 self.session.send(A, None, content=content, parent=parent,
101 101 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
102 102 ident, msg_list = self.session.feed_identities(B.recv_multipart())
103 new_msg = self.session.unserialize(msg_list)
103 new_msg = self.session.deserialize(msg_list)
104 104 self.assertEqual(ident[0], b'foo')
105 105 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
106 106 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
107 107 self.assertEqual(new_msg['header'],msg['header'])
108 108 self.assertEqual(new_msg['content'],msg['content'])
109 109 self.assertEqual(new_msg['metadata'],msg['metadata'])
110 110 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
111 111 self.assertEqual(new_msg['buffers'],[b'bar'])
112 112
113 113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
114 114 ident, new_msg = self.session.recv(B)
115 115 self.assertEqual(ident[0], b'foo')
116 116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
117 117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
118 118 self.assertEqual(new_msg['header'],msg['header'])
119 119 self.assertEqual(new_msg['content'],msg['content'])
120 120 self.assertEqual(new_msg['metadata'],msg['metadata'])
121 121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
122 122 self.assertEqual(new_msg['buffers'],[b'bar'])
123 123
124 124 A.close()
125 125 B.close()
126 126 ctx.term()
127 127
128 128 def test_args(self):
129 129 """initialization arguments for Session"""
130 130 s = self.session
131 131 self.assertTrue(s.pack is ss.default_packer)
132 132 self.assertTrue(s.unpack is ss.default_unpacker)
133 133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
134 134
135 135 s = ss.Session()
136 136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
137 137
138 138 self.assertRaises(TypeError, ss.Session, pack='hi')
139 139 self.assertRaises(TypeError, ss.Session, unpack='hi')
140 140 u = str(uuid.uuid4())
141 141 s = ss.Session(username=u'carrot', session=u)
142 142 self.assertEqual(s.session, u)
143 143 self.assertEqual(s.username, u'carrot')
144 144
145 145 def test_tracking(self):
146 146 """test tracking messages"""
147 147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
148 148 s = self.session
149 149 s.copy_threshold = 1
150 150 stream = ZMQStream(a)
151 151 msg = s.send(a, 'hello', track=False)
152 152 self.assertTrue(msg['tracker'] is ss.DONE)
153 153 msg = s.send(a, 'hello', track=True)
154 154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
155 155 M = zmq.Message(b'hi there', track=True)
156 156 msg = s.send(a, 'hello', buffers=[M], track=True)
157 157 t = msg['tracker']
158 158 self.assertTrue(isinstance(t, zmq.MessageTracker))
159 159 self.assertRaises(zmq.NotDone, t.wait, .1)
160 160 del M
161 161 t.wait(1) # this will raise
162 162
163 163
164 164 def test_unique_msg_ids(self):
165 165 """test that messages receive unique ids"""
166 166 ids = set()
167 167 for i in range(2**12):
168 168 h = self.session.msg_header('test')
169 169 msg_id = h['msg_id']
170 170 self.assertTrue(msg_id not in ids)
171 171 ids.add(msg_id)
172 172
173 173 def test_feed_identities(self):
174 174 """scrub the front for zmq IDENTITIES"""
175 175 theids = "engine client other".split()
176 176 content = dict(code='whoda',stuff=object())
177 177 themsg = self.session.msg('execute',content=content)
178 178 pmsg = theids
179 179
180 180 def test_session_id(self):
181 181 session = ss.Session()
182 182 # get bs before us
183 183 bs = session.bsession
184 184 us = session.session
185 185 self.assertEqual(us.encode('ascii'), bs)
186 186 session = ss.Session()
187 187 # get us before bs
188 188 us = session.session
189 189 bs = session.bsession
190 190 self.assertEqual(us.encode('ascii'), bs)
191 191 # change propagates:
192 192 session.session = 'something else'
193 193 bs = session.bsession
194 194 us = session.session
195 195 self.assertEqual(us.encode('ascii'), bs)
196 196 session = ss.Session(session='stuff')
197 197 # get us before bs
198 198 self.assertEqual(session.bsession, session.session.encode('ascii'))
199 199 self.assertEqual(b'stuff', session.bsession)
200 200
201 201 def test_zero_digest_history(self):
202 202 session = ss.Session(digest_history_size=0)
203 203 for i in range(11):
204 204 session._add_digest(uuid.uuid4().bytes)
205 205 self.assertEqual(len(session.digest_history), 0)
206 206
207 207 def test_cull_digest_history(self):
208 208 session = ss.Session(digest_history_size=100)
209 209 for i in range(100):
210 210 session._add_digest(uuid.uuid4().bytes)
211 211 self.assertTrue(len(session.digest_history) == 100)
212 212 session._add_digest(uuid.uuid4().bytes)
213 213 self.assertTrue(len(session.digest_history) == 91)
214 214 for i in range(9):
215 215 session._add_digest(uuid.uuid4().bytes)
216 216 self.assertTrue(len(session.digest_history) == 100)
217 217 session._add_digest(uuid.uuid4().bytes)
218 218 self.assertTrue(len(session.digest_history) == 91)
219 219
220 220 def test_bad_pack(self):
221 221 try:
222 222 session = ss.Session(pack=_bad_packer)
223 223 except ValueError as e:
224 224 self.assertIn("could not serialize", str(e))
225 225 self.assertIn("don't work", str(e))
226 226 else:
227 227 self.fail("Should have raised ValueError")
228 228
229 229 def test_bad_unpack(self):
230 230 try:
231 231 session = ss.Session(unpack=_bad_unpacker)
232 232 except ValueError as e:
233 233 self.assertIn("could not handle output", str(e))
234 234 self.assertIn("don't work either", str(e))
235 235 else:
236 236 self.fail("Should have raised ValueError")
237 237
238 238 def test_bad_packer(self):
239 239 try:
240 240 session = ss.Session(packer=__name__ + '._bad_packer')
241 241 except ValueError as e:
242 242 self.assertIn("could not serialize", str(e))
243 243 self.assertIn("don't work", str(e))
244 244 else:
245 245 self.fail("Should have raised ValueError")
246 246
247 247 def test_bad_unpacker(self):
248 248 try:
249 249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
250 250 except ValueError as e:
251 251 self.assertIn("could not handle output", str(e))
252 252 self.assertIn("don't work either", str(e))
253 253 else:
254 254 self.fail("Should have raised ValueError")
255 255
256 256 def test_bad_roundtrip(self):
257 257 with self.assertRaises(ValueError):
258 258 session = ss.Session(unpack=lambda b: 5)
259 259
260 260 def _datetime_test(self, session):
261 261 content = dict(t=datetime.now())
262 262 metadata = dict(t=datetime.now())
263 263 p = session.msg('msg')
264 264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
265 265 smsg = session.serialize(msg)
266 msg2 = session.unserialize(session.feed_identities(smsg)[1])
266 msg2 = session.deserialize(session.feed_identities(smsg)[1])
267 267 assert isinstance(msg2['header']['date'], datetime)
268 268 self.assertEqual(msg['header'], msg2['header'])
269 269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
271 271 assert isinstance(msg['content']['t'], datetime)
272 272 assert isinstance(msg['metadata']['t'], datetime)
273 273 assert isinstance(msg2['content']['t'], string_types)
274 274 assert isinstance(msg2['metadata']['t'], string_types)
275 275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
277 277
278 278 def test_datetimes(self):
279 279 self._datetime_test(self.session)
280 280
281 281 def test_datetimes_pickle(self):
282 282 session = ss.Session(packer='pickle')
283 283 self._datetime_test(session)
284 284
285 285 @skipif(module_not_available('msgpack'))
286 286 def test_datetimes_msgpack(self):
287 287 import msgpack
288 288
289 289 session = ss.Session(
290 290 pack=msgpack.packb,
291 291 unpack=lambda buf: msgpack.unpackb(buf, encoding='utf8'),
292 292 )
293 293 self._datetime_test(session)
294 294
295 295 def test_send_raw(self):
296 296 ctx = zmq.Context.instance()
297 297 A = ctx.socket(zmq.PAIR)
298 298 B = ctx.socket(zmq.PAIR)
299 299 A.bind("inproc://test")
300 300 B.connect("inproc://test")
301 301
302 302 msg = self.session.msg('execute', content=dict(a=10))
303 303 msg_list = [self.session.pack(msg[part]) for part in
304 304 ['header', 'parent_header', 'metadata', 'content']]
305 305 self.session.send_raw(A, msg_list, ident=b'foo')
306 306
307 307 ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
308 new_msg = self.session.unserialize(new_msg_list)
308 new_msg = self.session.deserialize(new_msg_list)
309 309 self.assertEqual(ident[0], b'foo')
310 310 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
311 311 self.assertEqual(new_msg['header'],msg['header'])
312 312 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
313 313 self.assertEqual(new_msg['content'],msg['content'])
314 314 self.assertEqual(new_msg['metadata'],msg['metadata'])
315 315
316 316 A.close()
317 317 B.close()
318 318 ctx.term()
@@ -1,1893 +1,1893
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 21
22 22 from IPython.config.configurable import MultipleInstanceError
23 23 from IPython.core.application import BaseIPythonApplication
24 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 25
26 26 from IPython.utils.capture import RichOutput
27 27 from IPython.utils.coloransi import TermColors
28 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 30 from IPython.utils.path import get_ipython_dir, compress_user
31 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 33 Dict, List, Bool, Set, Any)
34 34 from IPython.external.decorator import decorator
35 35
36 36 from IPython.parallel import Reference
37 37 from IPython.parallel import error
38 38 from IPython.parallel import util
39 39
40 40 from IPython.kernel.zmq.session import Session, Message
41 41 from IPython.kernel.zmq import serialize
42 42
43 43 from .asyncresult import AsyncResult, AsyncHubResult
44 44 from .view import DirectView, LoadBalancedView
45 45
46 46 #--------------------------------------------------------------------------
47 47 # Decorators for Client methods
48 48 #--------------------------------------------------------------------------
49 49
50 50
51 51 @decorator
52 52 def spin_first(f, self, *args, **kwargs):
53 53 """Call spin() to sync state prior to calling the method."""
54 54 self.spin()
55 55 return f(self, *args, **kwargs)
56 56
57 57
58 58 #--------------------------------------------------------------------------
59 59 # Classes
60 60 #--------------------------------------------------------------------------
61 61
62 62 _no_connection_file_msg = """
63 63 Failed to connect because no Controller could be found.
64 64 Please double-check your profile and ensure that a cluster is running.
65 65 """
66 66
67 67 class ExecuteReply(RichOutput):
68 68 """wrapper for finished Execute results"""
69 69 def __init__(self, msg_id, content, metadata):
70 70 self.msg_id = msg_id
71 71 self._content = content
72 72 self.execution_count = content['execution_count']
73 73 self.metadata = metadata
74 74
75 75 # RichOutput overrides
76 76
77 77 @property
78 78 def source(self):
79 79 execute_result = self.metadata['execute_result']
80 80 if execute_result:
81 81 return execute_result.get('source', '')
82 82
83 83 @property
84 84 def data(self):
85 85 execute_result = self.metadata['execute_result']
86 86 if execute_result:
87 87 return execute_result.get('data', {})
88 88
89 89 @property
90 90 def _metadata(self):
91 91 execute_result = self.metadata['execute_result']
92 92 if execute_result:
93 93 return execute_result.get('metadata', {})
94 94
95 95 def display(self):
96 96 from IPython.display import publish_display_data
97 97 publish_display_data(self.data, self.metadata)
98 98
99 99 def _repr_mime_(self, mime):
100 100 if mime not in self.data:
101 101 return
102 102 data = self.data[mime]
103 103 if mime in self._metadata:
104 104 return data, self._metadata[mime]
105 105 else:
106 106 return data
107 107
108 108 def __getitem__(self, key):
109 109 return self.metadata[key]
110 110
111 111 def __getattr__(self, key):
112 112 if key not in self.metadata:
113 113 raise AttributeError(key)
114 114 return self.metadata[key]
115 115
116 116 def __repr__(self):
117 117 execute_result = self.metadata['execute_result'] or {'data':{}}
118 118 text_out = execute_result['data'].get('text/plain', '')
119 119 if len(text_out) > 32:
120 120 text_out = text_out[:29] + '...'
121 121
122 122 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
123 123
124 124 def _repr_pretty_(self, p, cycle):
125 125 execute_result = self.metadata['execute_result'] or {'data':{}}
126 126 text_out = execute_result['data'].get('text/plain', '')
127 127
128 128 if not text_out:
129 129 return
130 130
131 131 try:
132 132 ip = get_ipython()
133 133 except NameError:
134 134 colors = "NoColor"
135 135 else:
136 136 colors = ip.colors
137 137
138 138 if colors == "NoColor":
139 139 out = normal = ""
140 140 else:
141 141 out = TermColors.Red
142 142 normal = TermColors.Normal
143 143
144 144 if '\n' in text_out and not text_out.startswith('\n'):
145 145 # add newline for multiline reprs
146 146 text_out = '\n' + text_out
147 147
148 148 p.text(
149 149 out + u'Out[%i:%i]: ' % (
150 150 self.metadata['engine_id'], self.execution_count
151 151 ) + normal + text_out
152 152 )
153 153
154 154
155 155 class Metadata(dict):
156 156 """Subclass of dict for initializing metadata values.
157 157
158 158 Attribute access works on keys.
159 159
160 160 These objects have a strict set of keys - errors will raise if you try
161 161 to add new keys.
162 162 """
163 163 def __init__(self, *args, **kwargs):
164 164 dict.__init__(self)
165 165 md = {'msg_id' : None,
166 166 'submitted' : None,
167 167 'started' : None,
168 168 'completed' : None,
169 169 'received' : None,
170 170 'engine_uuid' : None,
171 171 'engine_id' : None,
172 172 'follow' : None,
173 173 'after' : None,
174 174 'status' : None,
175 175
176 176 'execute_input' : None,
177 177 'execute_result' : None,
178 178 'error' : None,
179 179 'stdout' : '',
180 180 'stderr' : '',
181 181 'outputs' : [],
182 182 'data': {},
183 183 'outputs_ready' : False,
184 184 }
185 185 self.update(md)
186 186 self.update(dict(*args, **kwargs))
187 187
188 188 def __getattr__(self, key):
189 189 """getattr aliased to getitem"""
190 190 if key in self:
191 191 return self[key]
192 192 else:
193 193 raise AttributeError(key)
194 194
195 195 def __setattr__(self, key, value):
196 196 """setattr aliased to setitem, with strict"""
197 197 if key in self:
198 198 self[key] = value
199 199 else:
200 200 raise AttributeError(key)
201 201
202 202 def __setitem__(self, key, value):
203 203 """strict static key enforcement"""
204 204 if key in self:
205 205 dict.__setitem__(self, key, value)
206 206 else:
207 207 raise KeyError(key)
208 208
209 209
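# A quick sketch of the strict-key behavior described above (illustrative):
def _metadata_example():
    md = Metadata(status='ok')   # known keys may be set at construction
    md.engine_id = 0             # attribute access works for existing keys
    try:
        md.nonsense = 1          # unknown keys are rejected
    except AttributeError:
        pass
    return md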
210 210 class Client(HasTraits):
211 211 """A semi-synchronous client to the IPython ZMQ cluster
212 212
213 213 Parameters
214 214 ----------
215 215
216 216 url_file : str/unicode; path to ipcontroller-client.json
217 217 This JSON file should contain all the information needed to connect to a cluster,
218 218 and is likely the only argument needed.
219 219 Connection information for the Hub's registration. If a json connector
220 220 file is given, then likely no further configuration is necessary.
221 221 [Default: use profile]
222 222 profile : bytes
223 223 The name of the Cluster profile to be used to find connector information.
224 224 If run from an IPython application, the default profile will be the same
225 225 as the running application, otherwise it will be 'default'.
226 226 cluster_id : str
227 227 String id added to runtime files, to prevent name collisions when using
228 228 multiple clusters with a single profile simultaneously.
229 229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
230 230 Since this is text inserted into filenames, typical recommendations apply:
231 231 Simple character strings are ideal, and spaces are not recommended (but
232 232 should generally work)
233 233 context : zmq.Context
234 234 Pass an existing zmq.Context instance, otherwise the client will create its own.
235 235 debug : bool
236 236 flag for lots of message printing for debug purposes
237 237 timeout : int/float
238 238 time (in seconds) to wait for connection replies from the Hub
239 239 [Default: 10]
240 240
241 241 #-------------- session related args ----------------
242 242
243 243 config : Config object
244 244 If specified, this will be relayed to the Session for configuration
245 245 username : str
246 246 set username for the session object
247 247
248 248 #-------------- ssh related args ----------------
249 249 # These are args for configuring the ssh tunnel to be used
250 250 # credentials are used to forward connections over ssh to the Controller
251 251 # Note that the ip given in `addr` needs to be relative to sshserver
252 252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
253 253 # and set sshserver as the same machine the Controller is on. However,
254 254 # the only requirement is that sshserver is able to see the Controller
255 255 # (i.e. is within the same trusted network).
256 256
257 257 sshserver : str
258 258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
259 259 If keyfile or password is specified, and this is not, it will default to
260 260 the ip given in addr.
261 261 sshkey : str; path to ssh private key file
262 262 This specifies a key to be used in ssh login, default None.
263 263 Regular default ssh keys will be used without specifying this argument.
264 264 password : str
265 265 Your ssh password to sshserver. Note that if this is left None,
266 266 you will be prompted for it if passwordless key based login is unavailable.
267 267 paramiko : bool
268 268 flag for whether to use paramiko instead of shell ssh for tunneling.
269 269 [default: True on win32, False else]
270 270
271 271
272 272 Attributes
273 273 ----------
274 274
275 275 ids : list of int engine IDs
276 276 requesting the ids attribute always synchronizes
277 277 the registration state. To request ids without synchronization,
278 278 use the semi-private _ids attribute.
279 279
280 280 history : list of msg_ids
281 281 a list of msg_ids, keeping track of all the execution
282 282 messages you have submitted in order.
283 283
284 284 outstanding : set of msg_ids
285 285 a set of msg_ids that have been submitted, but whose
286 286 results have not yet been received.
287 287
288 288 results : dict
289 289 a dict of all our results, keyed by msg_id
290 290
291 291 block : bool
292 292 determines default behavior when block not specified
293 293 in execution methods
294 294
295 295 Methods
296 296 -------
297 297
298 298 spin
299 299 flushes incoming results and registration state changes
300 300 control methods spin, and requesting `ids` also ensures up to date
301 301
302 302 wait
303 303 wait on one or more msg_ids
304 304
305 305 execution methods
306 306 apply
307 307 legacy: execute, run
308 308
309 309 data movement
310 310 push, pull, scatter, gather
311 311
312 312 query methods
313 313 queue_status, get_result, purge, result_status
314 314
315 315 control methods
316 316 abort, shutdown
317 317
318 318 """
319 319
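    # Typical construction, as a sketch (hosts and paths hypothetical):
    #     rc = Client()                                  # default profile's JSON file
    #     rc = Client('/path/to/ipcontroller-client.json')
    #     rc = Client(sshserver='user@server.tld')       # tunnel to a remote controller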
320 320
321 321 block = Bool(False)
322 322 outstanding = Set()
323 323 results = Instance('collections.defaultdict', (dict,))
324 324 metadata = Instance('collections.defaultdict', (Metadata,))
325 325 history = List()
326 326 debug = Bool(False)
327 327 _spin_thread = Any()
328 328 _stop_spinning = Any()
329 329
330 330 profile=Unicode()
331 331 def _profile_default(self):
332 332 if BaseIPythonApplication.initialized():
333 333 # an IPython app *might* be running, try to get its profile
334 334 try:
335 335 return BaseIPythonApplication.instance().profile
336 336 except (AttributeError, MultipleInstanceError):
337 337 # could be a *different* subclass of config.Application,
338 338 # which would raise one of these two errors.
339 339 return u'default'
340 340 else:
341 341 return u'default'
342 342
343 343
344 344 _outstanding_dict = Instance('collections.defaultdict', (set,))
345 345 _ids = List()
346 346 _connected=Bool(False)
347 347 _ssh=Bool(False)
348 348 _context = Instance('zmq.Context')
349 349 _config = Dict()
350 350 _engines=Instance(util.ReverseDict, (), {})
351 351 # _hub_socket=Instance('zmq.Socket')
352 352 _query_socket=Instance('zmq.Socket')
353 353 _control_socket=Instance('zmq.Socket')
354 354 _iopub_socket=Instance('zmq.Socket')
355 355 _notification_socket=Instance('zmq.Socket')
356 356 _mux_socket=Instance('zmq.Socket')
357 357 _task_socket=Instance('zmq.Socket')
358 358 _task_scheme=Unicode()
359 359 _closed = False
360 360 _ignored_control_replies=Integer(0)
361 361 _ignored_hub_replies=Integer(0)
362 362
363 363 def __new__(self, *args, **kw):
364 364 # don't raise on positional args
365 365 return HasTraits.__new__(self, **kw)
366 366
367 367 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
368 368 context=None, debug=False,
369 369 sshserver=None, sshkey=None, password=None, paramiko=None,
370 370 timeout=10, cluster_id=None, **extra_args
371 371 ):
372 372 if profile:
373 373 super(Client, self).__init__(debug=debug, profile=profile)
374 374 else:
375 375 super(Client, self).__init__(debug=debug)
376 376 if context is None:
377 377 context = zmq.Context.instance()
378 378 self._context = context
379 379 self._stop_spinning = Event()
380 380
381 381 if 'url_or_file' in extra_args:
382 382 url_file = extra_args['url_or_file']
383 383 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
384 384
385 385 if url_file and util.is_url(url_file):
386 386 raise ValueError("single urls cannot be specified, url-files must be used.")
387 387
388 388 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
389 389
390 390 no_file_msg = '\n'.join([
391 391 "You have attempted to connect to an IPython Cluster but no Controller could be found.",
392 392 "Please double-check your configuration and ensure that a cluster is running.",
393 393 ])
394 394
395 395 if self._cd is not None:
396 396 if url_file is None:
397 397 if not cluster_id:
398 398 client_json = 'ipcontroller-client.json'
399 399 else:
400 400 client_json = 'ipcontroller-%s-client.json' % cluster_id
401 401 url_file = pjoin(self._cd.security_dir, client_json)
402 402 if not os.path.exists(url_file):
403 403 msg = '\n'.join([
404 404 "Connection file %r not found." % compress_user(url_file),
405 405 no_file_msg,
406 406 ])
407 407 raise IOError(msg)
408 408 if url_file is None:
409 409 raise IOError(no_file_msg)
410 410
411 411 if not os.path.exists(url_file):
412 412 # Connection file explicitly specified, but not found
413 413 raise IOError("Connection file %r not found. Is a controller running?" % \
414 414 compress_user(url_file)
415 415 )
416 416
417 417 with open(url_file) as f:
418 418 cfg = json.load(f)
419 419
420 420 self._task_scheme = cfg['task_scheme']
421 421
422 422 # sync defaults from args, json:
423 423 if sshserver:
424 424 cfg['ssh'] = sshserver
425 425
426 426 location = cfg.setdefault('location', None)
427 427
428 428 proto,addr = cfg['interface'].split('://')
429 429 addr = util.disambiguate_ip_address(addr, location)
430 430 cfg['interface'] = "%s://%s" % (proto, addr)
431 431
432 432 # turn interface,port into full urls:
433 433 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
434 434 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
435 435
436 436 url = cfg['registration']
437 437
438 438 if location is not None and addr == localhost():
439 439 # location specified, and connection is expected to be local
440 440 if not is_local_ip(location) and not sshserver:
441 441 # load ssh from JSON *only* if the controller is not on
442 442 # this machine
443 443 sshserver=cfg['ssh']
444 444 if not is_local_ip(location) and not sshserver:
445 445 # warn if no ssh specified, but SSH is probably needed
446 446 # This is only a warning, because the most likely cause
447 447 # is a local Controller on a laptop whose IP is dynamic
448 448 warnings.warn("""
449 449 Controller appears to be listening on localhost, but not on this machine.
450 450 If this is true, you should specify Client(...,sshserver='you@%s')
451 451 or instruct your controller to listen on an external IP."""%location,
452 452 RuntimeWarning)
453 453 elif not sshserver:
454 454 # otherwise sync with cfg
455 455 sshserver = cfg['ssh']
456 456
457 457 self._config = cfg
458 458
459 459 self._ssh = bool(sshserver or sshkey or password)
460 460 if self._ssh and sshserver is None:
461 461 # default to ssh via localhost
462 462 sshserver = addr
463 463 if self._ssh and password is None:
464 464 from zmq.ssh import tunnel
465 465 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
466 466 password=False
467 467 else:
468 468 password = getpass("SSH Password for %s: "%sshserver)
469 469 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
470 470
471 471 # configure and construct the session
472 472 try:
473 473 extra_args['packer'] = cfg['pack']
474 474 extra_args['unpacker'] = cfg['unpack']
475 475 extra_args['key'] = cast_bytes(cfg['key'])
476 476 extra_args['signature_scheme'] = cfg['signature_scheme']
477 477 except KeyError as exc:
478 478 msg = '\n'.join([
479 479 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
480 480 "If you are reusing connection files, remove them and start ipcontroller again."
481 481 ])
482 482 raise ValueError(msg.format(exc.message))
483 483
484 484 self.session = Session(**extra_args)
485 485
486 486 self._query_socket = self._context.socket(zmq.DEALER)
487 487
488 488 if self._ssh:
489 489 from zmq.ssh import tunnel
490 490 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
491 491 else:
492 492 self._query_socket.connect(cfg['registration'])
493 493
494 494 self.session.debug = self.debug
495 495
496 496 self._notification_handlers = {'registration_notification' : self._register_engine,
497 497 'unregistration_notification' : self._unregister_engine,
498 498 'shutdown_notification' : lambda msg: self.close(),
499 499 }
500 500 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
501 501 'apply_reply' : self._handle_apply_reply}
502 502
503 503 try:
504 504 self._connect(sshserver, ssh_kwargs, timeout)
505 505 except:
506 506 self.close(linger=0)
507 507 raise
508 508
509 509 # last step: setup magics, if we are in IPython:
510 510
511 511 try:
512 512 ip = get_ipython()
513 513 except NameError:
514 514 return
515 515 else:
516 516 if 'px' not in ip.magics_manager.magics:
517 517 # in IPython but we are the first Client.
518 518 # activate a default view for parallel magics.
519 519 self.activate()
520 520
521 521 def __del__(self):
522 522 """cleanup sockets, but _not_ context."""
523 523 self.close()
524 524
525 525 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
526 526 if ipython_dir is None:
527 527 ipython_dir = get_ipython_dir()
528 528 if profile_dir is not None:
529 529 try:
530 530 self._cd = ProfileDir.find_profile_dir(profile_dir)
531 531 return
532 532 except ProfileDirError:
533 533 pass
534 534 elif profile is not None:
535 535 try:
536 536 self._cd = ProfileDir.find_profile_dir_by_name(
537 537 ipython_dir, profile)
538 538 return
539 539 except ProfileDirError:
540 540 pass
541 541 self._cd = None
542 542
543 543 def _update_engines(self, engines):
544 544 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
545 545 for k,v in iteritems(engines):
546 546 eid = int(k)
547 547 if eid not in self._engines:
548 548 self._ids.append(eid)
549 549 self._engines[eid] = v
550 550 self._ids = sorted(self._ids)
551 551 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
552 552 self._task_scheme == 'pure' and self._task_socket:
553 553 self._stop_scheduling_tasks()
554 554
555 555 def _stop_scheduling_tasks(self):
556 556 """Stop scheduling tasks because an engine has been unregistered
557 557 from a pure ZMQ scheduler.
558 558 """
559 559 self._task_socket.close()
560 560 self._task_socket = None
561 561 msg = "An engine has been unregistered, and we are using pure " +\
562 562 "ZMQ task scheduling. Task farming will be disabled."
563 563 if self.outstanding:
564 564 msg += " If you were running tasks when this happened, " +\
565 565 "some `outstanding` msg_ids may never resolve."
566 566 warnings.warn(msg, RuntimeWarning)
567 567
568 568 def _build_targets(self, targets):
569 569 """Turn valid target IDs or 'all' into two lists:
570 570 (uuids, int_ids).
571 571 """
572 572 if not self._ids:
573 573 # flush notification socket if no engines yet, just in case
574 574 if not self.ids:
575 575 raise error.NoEnginesRegistered("Can't build targets without any engines")
576 576
577 577 if targets is None:
578 578 targets = self._ids
579 579 elif isinstance(targets, string_types):
580 580 if targets.lower() == 'all':
581 581 targets = self._ids
582 582 else:
583 583 raise TypeError("%r not valid str target, must be 'all'"%(targets))
584 584 elif isinstance(targets, int):
585 585 if targets < 0:
586 586 targets = self.ids[targets]
587 587 if targets not in self._ids:
588 588 raise IndexError("No such engine: %i"%targets)
589 589 targets = [targets]
590 590
591 591 if isinstance(targets, slice):
592 592 indices = list(range(len(self._ids))[targets])
593 593 ids = self.ids
594 594 targets = [ ids[i] for i in indices ]
595 595
596 596 if not isinstance(targets, (tuple, list, xrange)):
597 597 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
598 598
599 599 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
600 600
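    # A sketch of the normalization above (assumes engines 0-3 are registered):
    #     self._build_targets('all')           # -> (uuids of 0-3, [0, 1, 2, 3])
    #     self._build_targets(-1)              # -> the last registered engine only
    #     self._build_targets(slice(0, 4, 2))  # -> engines [0, 2]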
601 601 def _connect(self, sshserver, ssh_kwargs, timeout):
602 602 """setup all our socket connections to the cluster. This is called from
603 603 __init__."""
604 604
605 605 # Maybe allow reconnecting?
606 606 if self._connected:
607 607 return
608 608 self._connected=True
609 609
610 610 def connect_socket(s, url):
611 611 if self._ssh:
612 612 from zmq.ssh import tunnel
613 613 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
614 614 else:
615 615 return s.connect(url)
616 616
617 617 self.session.send(self._query_socket, 'connection_request')
618 618 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
619 619 poller = zmq.Poller()
620 620 poller.register(self._query_socket, zmq.POLLIN)
621 621 # poll expects milliseconds, timeout is seconds
622 622 evts = poller.poll(timeout*1000)
623 623 if not evts:
624 624 raise error.TimeoutError("Hub connection request timed out")
625 625 idents,msg = self.session.recv(self._query_socket,mode=0)
626 626 if self.debug:
627 627 pprint(msg)
628 628 content = msg['content']
629 629 # self._config['registration'] = dict(content)
630 630 cfg = self._config
631 631 if content['status'] == 'ok':
632 632 self._mux_socket = self._context.socket(zmq.DEALER)
633 633 connect_socket(self._mux_socket, cfg['mux'])
634 634
635 635 self._task_socket = self._context.socket(zmq.DEALER)
636 636 connect_socket(self._task_socket, cfg['task'])
637 637
638 638 self._notification_socket = self._context.socket(zmq.SUB)
639 639 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
640 640 connect_socket(self._notification_socket, cfg['notification'])
641 641
642 642 self._control_socket = self._context.socket(zmq.DEALER)
643 643 connect_socket(self._control_socket, cfg['control'])
644 644
645 645 self._iopub_socket = self._context.socket(zmq.SUB)
646 646 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
647 647 connect_socket(self._iopub_socket, cfg['iopub'])
648 648
649 649 self._update_engines(dict(content['engines']))
650 650 else:
651 651 self._connected = False
652 652 raise Exception("Failed to connect!")
653 653
654 654 #--------------------------------------------------------------------------
655 655 # handlers and callbacks for incoming messages
656 656 #--------------------------------------------------------------------------
657 657
658 658 def _unwrap_exception(self, content):
659 659 """unwrap exception, and remap engine_id to int."""
660 660 e = error.unwrap_exception(content)
661 661 # print e.traceback
662 662 if e.engine_info:
663 663 e_uuid = e.engine_info['engine_uuid']
664 664 eid = self._engines[e_uuid]
665 665 e.engine_info['engine_id'] = eid
666 666 return e
667 667
668 668 def _extract_metadata(self, msg):
669 669 header = msg['header']
670 670 parent = msg['parent_header']
671 671 msg_meta = msg['metadata']
672 672 content = msg['content']
673 673 md = {'msg_id' : parent['msg_id'],
674 674 'received' : datetime.now(),
675 675 'engine_uuid' : msg_meta.get('engine', None),
676 676 'follow' : msg_meta.get('follow', []),
677 677 'after' : msg_meta.get('after', []),
678 678 'status' : content['status'],
679 679 }
680 680
681 681 if md['engine_uuid'] is not None:
682 682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
683 683
684 684 if 'date' in parent:
685 685 md['submitted'] = parent['date']
686 686 if 'started' in msg_meta:
687 687 md['started'] = parse_date(msg_meta['started'])
688 688 if 'date' in header:
689 689 md['completed'] = header['date']
690 690 return md
691 691
692 692 def _register_engine(self, msg):
693 693 """Register a new engine, and update our connection info."""
694 694 content = msg['content']
695 695 eid = content['id']
696 696 d = {eid : content['uuid']}
697 697 self._update_engines(d)
698 698
699 699 def _unregister_engine(self, msg):
700 700 """Unregister an engine that has died."""
701 701 content = msg['content']
702 702 eid = int(content['id'])
703 703 if eid in self._ids:
704 704 self._ids.remove(eid)
705 705 uuid = self._engines.pop(eid)
706 706
707 707 self._handle_stranded_msgs(eid, uuid)
708 708
709 709 if self._task_socket and self._task_scheme == 'pure':
710 710 self._stop_scheduling_tasks()
711 711
712 712 def _handle_stranded_msgs(self, eid, uuid):
713 713 """Handle messages known to be on an engine when the engine unregisters.
714 714
715 715 It is possible that this will fire prematurely - that is, an engine will
716 716 go down after completing a result, and the client will be notified
717 717 of the unregistration and later receive the successful result.
718 718 """
719 719
720 720 outstanding = self._outstanding_dict[uuid]
721 721
722 722 for msg_id in list(outstanding):
723 723 if msg_id in self.results:
724 724 # we already have this result
725 725 continue
726 726 try:
727 727 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
728 728 except:
729 729 content = error.wrap_exception()
730 730 # build a fake message:
731 731 msg = self.session.msg('apply_reply', content=content)
732 732 msg['parent_header']['msg_id'] = msg_id
733 733 msg['metadata']['engine'] = uuid
734 734 self._handle_apply_reply(msg)
735 735
736 736 def _handle_execute_reply(self, msg):
737 737 """Save the reply to an execute_request into our results.
738 738
739 739 execute messages are never actually used. apply is used instead.
740 740 """
741 741
742 742 parent = msg['parent_header']
743 743 msg_id = parent['msg_id']
744 744 if msg_id not in self.outstanding:
745 745 if msg_id in self.history:
746 746 print("got stale result: %s"%msg_id)
747 747 else:
748 748 print("got unknown result: %s"%msg_id)
749 749 else:
750 750 self.outstanding.remove(msg_id)
751 751
752 752 content = msg['content']
753 753 header = msg['header']
754 754
755 755 # construct metadata:
756 756 md = self.metadata[msg_id]
757 757 md.update(self._extract_metadata(msg))
758 758 # is this redundant?
759 759 self.metadata[msg_id] = md
760 760
761 761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
762 762 if msg_id in e_outstanding:
763 763 e_outstanding.remove(msg_id)
764 764
765 765 # construct result:
766 766 if content['status'] == 'ok':
767 767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
768 768 elif content['status'] == 'aborted':
769 769 self.results[msg_id] = error.TaskAborted(msg_id)
770 770 elif content['status'] == 'resubmitted':
771 771 # TODO: handle resubmission
772 772 pass
773 773 else:
774 774 self.results[msg_id] = self._unwrap_exception(content)
775 775
776 776 def _handle_apply_reply(self, msg):
777 777 """Save the reply to an apply_request into our results."""
778 778 parent = msg['parent_header']
779 779 msg_id = parent['msg_id']
780 780 if msg_id not in self.outstanding:
781 781 if msg_id in self.history:
782 782 print("got stale result: %s"%msg_id)
783 783 print(self.results[msg_id])
784 784 print(msg)
785 785 else:
786 786 print("got unknown result: %s"%msg_id)
787 787 else:
788 788 self.outstanding.remove(msg_id)
789 789 content = msg['content']
790 790 header = msg['header']
791 791
792 792 # construct metadata:
793 793 md = self.metadata[msg_id]
794 794 md.update(self._extract_metadata(msg))
795 795 # is this redundant?
796 796 self.metadata[msg_id] = md
797 797
798 798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
799 799 if msg_id in e_outstanding:
800 800 e_outstanding.remove(msg_id)
801 801
802 802 # construct result:
803 803 if content['status'] == 'ok':
804 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
805 805 elif content['status'] == 'aborted':
806 806 self.results[msg_id] = error.TaskAborted(msg_id)
807 807 elif content['status'] == 'resubmitted':
808 808 # TODO: handle resubmission
809 809 pass
810 810 else:
811 811 self.results[msg_id] = self._unwrap_exception(content)
812 812
813 813 def _flush_notifications(self):
814 814 """Flush notifications of engine registrations waiting
815 815 in ZMQ queue."""
816 816 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 817 while msg is not None:
818 818 if self.debug:
819 819 pprint(msg)
820 820 msg_type = msg['header']['msg_type']
821 821 handler = self._notification_handlers.get(msg_type, None)
822 822 if handler is None:
823 823 raise Exception("Unhandled message type: %s" % msg_type)
824 824 else:
825 825 handler(msg)
826 826 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
827 827
828 828 def _flush_results(self, sock):
829 829 """Flush task or queue results waiting in ZMQ queue."""
830 830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 831 while msg is not None:
832 832 if self.debug:
833 833 pprint(msg)
834 834 msg_type = msg['header']['msg_type']
835 835 handler = self._queue_handlers.get(msg_type, None)
836 836 if handler is None:
837 837 raise Exception("Unhandled message type: %s" % msg_type)
838 838 else:
839 839 handler(msg)
840 840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 841
842 842 def _flush_control(self, sock):
843 843 """Flush replies from the control channel waiting
844 844 in the ZMQ queue.
845 845
846 846 Currently: ignore them."""
847 847 if self._ignored_control_replies <= 0:
848 848 return
849 849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
850 850 while msg is not None:
851 851 self._ignored_control_replies -= 1
852 852 if self.debug:
853 853 pprint(msg)
854 854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
855 855
856 856 def _flush_ignored_control(self):
857 857 """flush ignored control replies"""
858 858 while self._ignored_control_replies > 0:
859 859 self.session.recv(self._control_socket)
860 860 self._ignored_control_replies -= 1
861 861
862 862 def _flush_ignored_hub_replies(self):
863 863 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
864 864 while msg is not None:
865 865 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
866 866
867 867 def _flush_iopub(self, sock):
868 868 """Flush replies from the iopub channel waiting
869 869 in the ZMQ queue.
870 870 """
871 871 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
872 872 while msg is not None:
873 873 if self.debug:
874 874 pprint(msg)
875 875 parent = msg['parent_header']
876 876 if not parent or parent['session'] != self.session.session:
877 877 # ignore IOPub messages not from here
878 878 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
879 879 continue
880 880 msg_id = parent['msg_id']
881 881 content = msg['content']
882 882 header = msg['header']
883 883 msg_type = msg['header']['msg_type']
884 884
885 885 if msg_type == 'status' and msg_id not in self.metadata:
886 886 # ignore status messages if they aren't mine
887 887 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
888 888 continue
889 889
890 890 # init metadata:
891 891 md = self.metadata[msg_id]
892 892
893 893 if msg_type == 'stream':
894 894 name = content['name']
895 895 s = md[name] or ''
896 896 md[name] = s + content['text']
897 897 elif msg_type == 'error':
898 898 md.update({'error' : self._unwrap_exception(content)})
899 899 elif msg_type == 'execute_input':
900 900 md.update({'execute_input' : content['code']})
901 901 elif msg_type == 'display_data':
902 902 md['outputs'].append(content)
903 903 elif msg_type == 'execute_result':
904 904 md['execute_result'] = content
905 905 elif msg_type == 'data_message':
906 data, remainder = serialize.unserialize_object(msg['buffers'])
906 data, remainder = serialize.deserialize_object(msg['buffers'])
907 907 md['data'].update(data)
908 908 elif msg_type == 'status':
909 909 # idle message comes after all outputs
910 910 if content['execution_state'] == 'idle':
911 911 md['outputs_ready'] = True
912 912 else:
913 913 # unhandled msg_type (status, etc.)
914 914 pass
915 915
916 916 # redundant?
917 917 self.metadata[msg_id] = md
918 918
919 919 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
920 920
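    # Editor's sketch (hedged): the net effect of _flush_iopub is to fold
    # IOPub traffic into self.metadata. After an engine prints and returns,
    # a record might look roughly like this (values hypothetical):
    #
    #     md = rc.metadata[msg_id]
    #     md['stdout']          # 'hello\n'  -- stream chunks concatenated above
    #     md['outputs']         # list of display_data contents
    #     md['execute_result']  # content of the execute_result message
    #     md['outputs_ready']   # True once the final 'idle' status arrives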
921 921 #--------------------------------------------------------------------------
922 922 # len, getitem
923 923 #--------------------------------------------------------------------------
924 924
925 925 def __len__(self):
926 926 """len(client) returns # of engines."""
927 927 return len(self.ids)
928 928
929 929 def __getitem__(self, key):
930 930 """index access returns DirectView multiplexer objects
931 931
932 932 Must be int, slice, or list/tuple/xrange of ints"""
933 933 if not isinstance(key, (int, slice, tuple, list, xrange)):
934 934 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
935 935 else:
936 936 return self.direct_view(key)
937 937
938 938 def __iter__(self):
939 939 """Since we define getitem, Client is iterable
940 940
941 941         but unless we also define __iter__, iteration will only work correctly
942 942         if engine IDs start at zero and are contiguous.
943 943 """
944 944 for eid in self.ids:
945 945 yield self.direct_view(eid)
946 946
947 947 #--------------------------------------------------------------------------
948 948 # Begin public methods
949 949 #--------------------------------------------------------------------------
950 950
951 951 @property
952 952 def ids(self):
953 953 """Always up-to-date ids property."""
954 954 self._flush_notifications()
955 955 # always copy:
956 956 return list(self._ids)
957 957
958 958 def activate(self, targets='all', suffix=''):
959 959 """Create a DirectView and register it with IPython magics
960 960
961 961 Defines the magics `%px, %autopx, %pxresult, %%px`
962 962
963 963 Parameters
964 964 ----------
965 965
966 966 targets: int, list of ints, or 'all'
967 967 The engines on which the view's magics will run
968 968 suffix: str [default: '']
969 969 The suffix, if any, for the magics. This allows you to have
970 970 multiple views associated with parallel magics at the same time.
971 971
972 972 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
973 973 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
974 974 on engine 0.
975 975 """
976 976 view = self.direct_view(targets)
977 977 view.block = True
978 978 view.activate(suffix)
979 979 return view
980 980
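    # A hedged usage sketch for activate(), assuming an interactive IPython
    # session with a running cluster and `rc = Client()`:
    #
    #     view = rc.activate(targets=[0, 1], suffix='01')
    #     %px01 import numpy     # runs only on engines 0 and 1
    #     %autopx01              # toggle automatic parallel execution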
981 981 def close(self, linger=None):
982 982 """Close my zmq Sockets
983 983
984 984 If `linger`, set the zmq LINGER socket option,
985 985 which allows discarding of messages.
986 986 """
987 987 if self._closed:
988 988 return
989 989 self.stop_spin_thread()
990 990 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
991 991 for name in snames:
992 992 socket = getattr(self, name)
993 993 if socket is not None and not socket.closed:
994 994 if linger is not None:
995 995 socket.close(linger=linger)
996 996 else:
997 997 socket.close()
998 998 self._closed = True
999 999
1000 1000 def _spin_every(self, interval=1):
1001 1001 """target func for use in spin_thread"""
1002 1002 while True:
1003 1003 if self._stop_spinning.is_set():
1004 1004 return
1005 1005 time.sleep(interval)
1006 1006 self.spin()
1007 1007
1008 1008 def spin_thread(self, interval=1):
1009 1009 """call Client.spin() in a background thread on some regular interval
1010 1010
1011 1011 This helps ensure that messages don't pile up too much in the zmq queue
1012 1012 while you are working on other things, or just leaving an idle terminal.
1013 1013
1014 1014 It also helps limit potential padding of the `received` timestamp
1015 1015 on AsyncResult objects, used for timings.
1016 1016
1017 1017 Parameters
1018 1018 ----------
1019 1019
1020 1020 interval : float, optional
1021 1021 The interval on which to spin the client in the background thread
1022 1022 (simply passed to time.sleep).
1023 1023
1024 1024 Notes
1025 1025 -----
1026 1026
1027 1027 For precision timing, you may want to use this method to put a bound
1028 1028 on the jitter (in seconds) in `received` timestamps used
1029 1029 in AsyncResult.wall_time.
1030 1030
1031 1031 """
1032 1032 if self._spin_thread is not None:
1033 1033 self.stop_spin_thread()
1034 1034 self._stop_spinning.clear()
1035 1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1036 1036 self._spin_thread.daemon = True
1037 1037 self._spin_thread.start()
1038 1038
1039 1039 def stop_spin_thread(self):
1040 1040 """stop background spin_thread, if any"""
1041 1041 if self._spin_thread is not None:
1042 1042 self._stop_spinning.set()
1043 1043 self._spin_thread.join()
1044 1044 self._spin_thread = None
1045 1045
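    # A hedged sketch of the background-spin pattern described above,
    # assuming `rc = Client()` against a running cluster:
    #
    #     rc.spin_thread(interval=0.1)       # flush queues every 100 ms
    #     ar = rc[:].apply_async(time.time)
    #     ar.wait()                          # `received` jitter now bounded by ~0.1 s
    #     rc.stop_spin_thread()              # stop the daemon thread when done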
1046 1046 def spin(self):
1047 1047 """Flush any registration notifications and execution results
1048 1048 waiting in the ZMQ queue.
1049 1049 """
1050 1050 if self._notification_socket:
1051 1051 self._flush_notifications()
1052 1052 if self._iopub_socket:
1053 1053 self._flush_iopub(self._iopub_socket)
1054 1054 if self._mux_socket:
1055 1055 self._flush_results(self._mux_socket)
1056 1056 if self._task_socket:
1057 1057 self._flush_results(self._task_socket)
1058 1058 if self._control_socket:
1059 1059 self._flush_control(self._control_socket)
1060 1060 if self._query_socket:
1061 1061 self._flush_ignored_hub_replies()
1062 1062
1063 1063 def wait(self, jobs=None, timeout=-1):
1064 1064 """waits on one or more `jobs`, for up to `timeout` seconds.
1065 1065
1066 1066 Parameters
1067 1067 ----------
1068 1068
1069 1069 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1070 1070 ints are indices to self.history
1071 1071 strs are msg_ids
1072 1072 default: wait on all outstanding messages
1073 1073 timeout : float
1074 1074 a time in seconds, after which to give up.
1075 1075 default is -1, which means no timeout
1076 1076
1077 1077 Returns
1078 1078 -------
1079 1079
1080 1080 True : when all msg_ids are done
1081 1081 False : timeout reached, some msg_ids still outstanding
1082 1082 """
1083 1083 tic = time.time()
1084 1084 if jobs is None:
1085 1085 theids = self.outstanding
1086 1086 else:
1087 1087 if isinstance(jobs, string_types + (int, AsyncResult)):
1088 1088 jobs = [jobs]
1089 1089 theids = set()
1090 1090 for job in jobs:
1091 1091 if isinstance(job, int):
1092 1092 # index access
1093 1093 job = self.history[job]
1094 1094 elif isinstance(job, AsyncResult):
1095 1095 theids.update(job.msg_ids)
1096 1096 continue
1097 1097 theids.add(job)
1098 1098 if not theids.intersection(self.outstanding):
1099 1099 return True
1100 1100 self.spin()
1101 1101 while theids.intersection(self.outstanding):
1102 1102 if timeout >= 0 and ( time.time()-tic ) > timeout:
1103 1103 break
1104 1104 time.sleep(1e-3)
1105 1105 self.spin()
1106 1106 return len(theids.intersection(self.outstanding)) == 0
1107 1107
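    # A hedged sketch of wait(), assuming `rc = Client()`:
    #
    #     ar = rc[:].apply_async(time.sleep, 1)
    #     rc.wait([ar], timeout=5)    # True once the sleep finishes everywhere
    #     rc.wait(rc.history[-2:])    # str msg_ids and history ints work too
    #     rc.wait()                   # block until *all* outstanding msgs are done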
1108 1108 #--------------------------------------------------------------------------
1109 1109 # Control methods
1110 1110 #--------------------------------------------------------------------------
1111 1111
1112 1112 @spin_first
1113 1113 def clear(self, targets=None, block=None):
1114 1114 """Clear the namespace in target(s)."""
1115 1115 block = self.block if block is None else block
1116 1116 targets = self._build_targets(targets)[0]
1117 1117 for t in targets:
1118 1118 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1119 1119 error = False
1120 1120 if block:
1121 1121 self._flush_ignored_control()
1122 1122 for i in range(len(targets)):
1123 1123 idents,msg = self.session.recv(self._control_socket,0)
1124 1124 if self.debug:
1125 1125 pprint(msg)
1126 1126 if msg['content']['status'] != 'ok':
1127 1127 error = self._unwrap_exception(msg['content'])
1128 1128 else:
1129 1129 self._ignored_control_replies += len(targets)
1130 1130 if error:
1131 1131 raise error
1132 1132
1133 1133
1134 1134 @spin_first
1135 1135 def abort(self, jobs=None, targets=None, block=None):
1136 1136 """Abort specific jobs from the execution queues of target(s).
1137 1137
1138 1138 This is a mechanism to prevent jobs that have already been submitted
1139 1139 from executing.
1140 1140
1141 1141 Parameters
1142 1142 ----------
1143 1143
1144 1144 jobs : msg_id, list of msg_ids, or AsyncResult
1145 1145 The jobs to be aborted
1146 1146
1147 1147 If unspecified/None: abort all outstanding jobs.
1148 1148
1149 1149 """
1150 1150 block = self.block if block is None else block
1151 1151 jobs = jobs if jobs is not None else list(self.outstanding)
1152 1152 targets = self._build_targets(targets)[0]
1153 1153
1154 1154 msg_ids = []
1155 1155 if isinstance(jobs, string_types + (AsyncResult,)):
1156 1156 jobs = [jobs]
1157 1157 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1158 1158 if bad_ids:
1159 1159 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1160 1160 for j in jobs:
1161 1161 if isinstance(j, AsyncResult):
1162 1162 msg_ids.extend(j.msg_ids)
1163 1163 else:
1164 1164 msg_ids.append(j)
1165 1165 content = dict(msg_ids=msg_ids)
1166 1166 for t in targets:
1167 1167 self.session.send(self._control_socket, 'abort_request',
1168 1168 content=content, ident=t)
1169 1169 error = False
1170 1170 if block:
1171 1171 self._flush_ignored_control()
1172 1172 for i in range(len(targets)):
1173 1173 idents,msg = self.session.recv(self._control_socket,0)
1174 1174 if self.debug:
1175 1175 pprint(msg)
1176 1176 if msg['content']['status'] != 'ok':
1177 1177 error = self._unwrap_exception(msg['content'])
1178 1178 else:
1179 1179 self._ignored_control_replies += len(targets)
1180 1180 if error:
1181 1181 raise error
1182 1182
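    # A hedged sketch of abort(), assuming a load-balanced queue with more
    # submitted work than the engines have started running:
    #
    #     ar = rc.load_balanced_view().map_async(time.sleep, [10] * 100)
    #     rc.abort(ar)    # tasks still waiting in engine queues will never run
    #     rc.abort()      # no arguments: abort everything outstanding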
1183 1183 @spin_first
1184 1184 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1185 1185 """Terminates one or more engine processes, optionally including the hub.
1186 1186
1187 1187 Parameters
1188 1188 ----------
1189 1189
1190 1190 targets: list of ints or 'all' [default: all]
1191 1191 Which engines to shutdown.
1192 1192 hub: bool [default: False]
1193 1193 Whether to include the Hub. hub=True implies targets='all'.
1194 1194 block: bool [default: self.block]
1195 1195 Whether to wait for clean shutdown replies or not.
1196 1196 restart: bool [default: False]
1197 1197 NOT IMPLEMENTED
1198 1198 whether to restart engines after shutting them down.
1199 1199 """
1200 1200 from IPython.parallel.error import NoEnginesRegistered
1201 1201 if restart:
1202 1202 raise NotImplementedError("Engine restart is not yet implemented")
1203 1203
1204 1204 block = self.block if block is None else block
1205 1205 if hub:
1206 1206 targets = 'all'
1207 1207 try:
1208 1208 targets = self._build_targets(targets)[0]
1209 1209 except NoEnginesRegistered:
1210 1210 targets = []
1211 1211 for t in targets:
1212 1212 self.session.send(self._control_socket, 'shutdown_request',
1213 1213 content={'restart':restart},ident=t)
1214 1214 error = False
1215 1215 if block or hub:
1216 1216 self._flush_ignored_control()
1217 1217 for i in range(len(targets)):
1218 1218 idents,msg = self.session.recv(self._control_socket, 0)
1219 1219 if self.debug:
1220 1220 pprint(msg)
1221 1221 if msg['content']['status'] != 'ok':
1222 1222 error = self._unwrap_exception(msg['content'])
1223 1223 else:
1224 1224 self._ignored_control_replies += len(targets)
1225 1225
1226 1226 if hub:
1227 1227 time.sleep(0.25)
1228 1228 self.session.send(self._query_socket, 'shutdown_request')
1229 1229 idents,msg = self.session.recv(self._query_socket, 0)
1230 1230 if self.debug:
1231 1231 pprint(msg)
1232 1232 if msg['content']['status'] != 'ok':
1233 1233 error = self._unwrap_exception(msg['content'])
1234 1234
1235 1235 if error:
1236 1236 raise error
1237 1237
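    # A hedged sketch of shutdown():
    #
    #     rc.shutdown(targets=[2, 3])          # stop two engines, keep the rest
    #     rc.shutdown(hub=True, block=True)    # stop all engines and the Hub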
1238 1238 #--------------------------------------------------------------------------
1239 1239 # Execution related methods
1240 1240 #--------------------------------------------------------------------------
1241 1241
1242 1242 def _maybe_raise(self, result):
1243 1243 """wrapper for maybe raising an exception if apply failed."""
1244 1244 if isinstance(result, error.RemoteError):
1245 1245 raise result
1246 1246
1247 1247 return result
1248 1248
1249 1249 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1250 1250 ident=None):
1251 1251 """construct and send an apply message via a socket.
1252 1252
1253 1253 This is the principal method with which all engine execution is performed by views.
1254 1254 """
1255 1255
1256 1256 if self._closed:
1257 1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1258 1258
1259 1259 # defaults:
1260 1260 args = args if args is not None else []
1261 1261 kwargs = kwargs if kwargs is not None else {}
1262 1262 metadata = metadata if metadata is not None else {}
1263 1263
1264 1264 # validate arguments
1265 1265 if not callable(f) and not isinstance(f, Reference):
1266 1266 raise TypeError("f must be callable, not %s"%type(f))
1267 1267 if not isinstance(args, (tuple, list)):
1268 1268 raise TypeError("args must be tuple or list, not %s"%type(args))
1269 1269 if not isinstance(kwargs, dict):
1270 1270 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1271 1271 if not isinstance(metadata, dict):
1272 1272 raise TypeError("metadata must be dict, not %s"%type(metadata))
1273 1273
1274 1274 bufs = serialize.pack_apply_message(f, args, kwargs,
1275 1275 buffer_threshold=self.session.buffer_threshold,
1276 1276 item_threshold=self.session.item_threshold,
1277 1277 )
1278 1278
1279 1279 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1280 1280 metadata=metadata, track=track)
1281 1281
1282 1282 msg_id = msg['header']['msg_id']
1283 1283 self.outstanding.add(msg_id)
1284 1284 if ident:
1285 1285 # possibly routed to a specific engine
1286 1286 if isinstance(ident, list):
1287 1287 ident = ident[-1]
1288 1288 if ident in self._engines.values():
1289 1289 # save for later, in case of engine death
1290 1290 self._outstanding_dict[ident].add(msg_id)
1291 1291 self.history.append(msg_id)
1292 1292 self.metadata[msg_id]['submitted'] = datetime.now()
1293 1293
1294 1294 return msg
1295 1295
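    # Editor's note (hedged): views funnel all execution through this method.
    # DirectView.apply_async(f, *args) reduces, roughly, to:
    #
    #     msg = rc.send_apply_request(rc._mux_socket, f, args, {},
    #                                 ident=engine_ident)  # engine_ident: hypothetical
    #     ar = AsyncResult(rc, msg_ids=[msg['header']['msg_id']])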
1296 1296 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1297 1297 """construct and send an execute request via a socket.
1298 1298
1299 1299 """
1300 1300
1301 1301 if self._closed:
1302 1302 raise RuntimeError("Client cannot be used after its sockets have been closed")
1303 1303
1304 1304 # defaults:
1305 1305 metadata = metadata if metadata is not None else {}
1306 1306
1307 1307 # validate arguments
1308 1308 if not isinstance(code, string_types):
1309 1309 raise TypeError("code must be text, not %s" % type(code))
1310 1310 if not isinstance(metadata, dict):
1311 1311 raise TypeError("metadata must be dict, not %s" % type(metadata))
1312 1312
1313 1313 content = dict(code=code, silent=bool(silent), user_expressions={})
1314 1314
1315 1315
1316 1316 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1317 1317 metadata=metadata)
1318 1318
1319 1319 msg_id = msg['header']['msg_id']
1320 1320 self.outstanding.add(msg_id)
1321 1321 if ident:
1322 1322 # possibly routed to a specific engine
1323 1323 if isinstance(ident, list):
1324 1324 ident = ident[-1]
1325 1325 if ident in self._engines.values():
1326 1326 # save for later, in case of engine death
1327 1327 self._outstanding_dict[ident].add(msg_id)
1328 1328 self.history.append(msg_id)
1329 1329 self.metadata[msg_id]['submitted'] = datetime.now()
1330 1330
1331 1331 return msg
1332 1332
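    # The execute path mirrors the apply path above; a hedged sketch:
    #
    #     msg = rc.send_execute_request(rc._mux_socket, "a = 1", ident=ident)
    #     msg_id = msg['header']['msg_id']   # track via rc.outstanding / rc.metadata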
1333 1333 #--------------------------------------------------------------------------
1334 1334 # construct a View object
1335 1335 #--------------------------------------------------------------------------
1336 1336
1337 1337 def load_balanced_view(self, targets=None):
1338 1338         """construct a LoadBalancedView object.
1339 1339
1340 1340 If no arguments are specified, create a LoadBalancedView
1341 1341 using all engines.
1342 1342
1343 1343 Parameters
1344 1344 ----------
1345 1345
1346 1346 targets: list,slice,int,etc. [default: use all engines]
1347 1347 The subset of engines across which to load-balance
1348 1348 """
1349 1349 if targets == 'all':
1350 1350 targets = None
1351 1351 if targets is not None:
1352 1352 targets = self._build_targets(targets)[1]
1353 1353 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1354 1354
1355 1355 def direct_view(self, targets='all'):
1356 1356 """construct a DirectView object.
1357 1357
1358 1358 If no targets are specified, create a DirectView using all engines.
1359 1359
1360 1360 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1361 1361 evaluate the target engines at each execution, whereas rc[:] will connect to
1362 1362 all *current* engines, and that list will not change.
1363 1363
1364 1364 That is, 'all' will always use all engines, whereas rc[:] will not use
1365 1365 engines added after the DirectView is constructed.
1366 1366
1367 1367 Parameters
1368 1368 ----------
1369 1369
1370 1370 targets: list,slice,int,etc. [default: use all engines]
1371 1371 The engines to use for the View
1372 1372 """
1373 1373 single = isinstance(targets, int)
1374 1374 # allow 'all' to be lazily evaluated at each execution
1375 1375 if targets != 'all':
1376 1376 targets = self._build_targets(targets)[1]
1377 1377 if single:
1378 1378 targets = targets[0]
1379 1379 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1380 1380
1381 1381 #--------------------------------------------------------------------------
1382 1382 # Query methods
1383 1383 #--------------------------------------------------------------------------
1384 1384
1385 1385 @spin_first
1386 1386 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1387 1387 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1388 1388
1389 1389 If the client already has the results, no request to the Hub will be made.
1390 1390
1391 1391 This is a convenient way to construct AsyncResult objects, which are wrappers
1392 1392 that include metadata about execution, and allow for awaiting results that
1393 1393 were not submitted by this Client.
1394 1394
1395 1395         It can also be a convenient way to retrieve the metadata associated with
1396 1396         blocking execution, since it always retrieves the metadata for each msg_id.
1397 1397
1398 1398 Examples
1399 1399 --------
1400 1400 ::
1401 1401
1402 1402             In [10]: ar = client.get_result(-1)  # most recent result
1403 1403
1404 1404 Parameters
1405 1405 ----------
1406 1406
1407 1407 indices_or_msg_ids : integer history index, str msg_id, or list of either
1408 1408             The history indices or msg_ids of the tasks to be retrieved
1409 1409
1410 1410 block : bool
1411 1411 Whether to wait for the result to be done
1412 1412 owner : bool [default: True]
1413 1413 Whether this AsyncResult should own the result.
1414 1414 If so, calling `ar.get()` will remove data from the
1415 1415 client's result and metadata cache.
1416 1416 There should only be one owner of any given msg_id.
1417 1417
1418 1418 Returns
1419 1419 -------
1420 1420
1421 1421 AsyncResult
1422 1422 A single AsyncResult object will always be returned.
1423 1423
1424 1424 AsyncHubResult
1425 1425 A subclass of AsyncResult that retrieves results from the Hub
1426 1426
1427 1427 """
1428 1428 block = self.block if block is None else block
1429 1429 if indices_or_msg_ids is None:
1430 1430 indices_or_msg_ids = -1
1431 1431
1432 1432 single_result = False
1433 1433 if not isinstance(indices_or_msg_ids, (list,tuple)):
1434 1434 indices_or_msg_ids = [indices_or_msg_ids]
1435 1435 single_result = True
1436 1436
1437 1437 theids = []
1438 1438 for id in indices_or_msg_ids:
1439 1439 if isinstance(id, int):
1440 1440 id = self.history[id]
1441 1441 if not isinstance(id, string_types):
1442 1442 raise TypeError("indices must be str or int, not %r"%id)
1443 1443 theids.append(id)
1444 1444
1445 1445 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1446 1446 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1447 1447
1448 1448         # given a single msg_id initially, get_result should get the result itself,
1449 1449 # not a length-one list
1450 1450 if single_result:
1451 1451 theids = theids[0]
1452 1452
1453 1453 if remote_ids:
1454 1454 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1455 1455 else:
1456 1456 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1457 1457
1458 1458 if block:
1459 1459 ar.wait()
1460 1460
1461 1461 return ar
1462 1462
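    # A hedged sketch of get_result(), assuming `rc = Client()`:
    #
    #     ar = rc.get_result(-1)        # most recent task, by history index
    #     ar = rc.get_result(msg_id)    # or by msg_id, even from another client
    #     ar.get()                      # wait for, and return, the result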
1463 1463 @spin_first
1464 1464 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1465 1465 """Resubmit one or more tasks.
1466 1466
1467 1467 in-flight tasks may not be resubmitted.
1468 1468
1469 1469 Parameters
1470 1470 ----------
1471 1471
1472 1472 indices_or_msg_ids : integer history index, str msg_id, or list of either
1473 1473             The history indices or msg_ids of the tasks to be resubmitted
1474 1474
1475 1475 block : bool
1476 1476 Whether to wait for the result to be done
1477 1477
1478 1478 Returns
1479 1479 -------
1480 1480
1481 1481 AsyncHubResult
1482 1482 A subclass of AsyncResult that retrieves results from the Hub
1483 1483
1484 1484 """
1485 1485 block = self.block if block is None else block
1486 1486 if indices_or_msg_ids is None:
1487 1487 indices_or_msg_ids = -1
1488 1488
1489 1489 if not isinstance(indices_or_msg_ids, (list,tuple)):
1490 1490 indices_or_msg_ids = [indices_or_msg_ids]
1491 1491
1492 1492 theids = []
1493 1493 for id in indices_or_msg_ids:
1494 1494 if isinstance(id, int):
1495 1495 id = self.history[id]
1496 1496 if not isinstance(id, string_types):
1497 1497 raise TypeError("indices must be str or int, not %r"%id)
1498 1498 theids.append(id)
1499 1499
1500 1500 content = dict(msg_ids = theids)
1501 1501
1502 1502 self.session.send(self._query_socket, 'resubmit_request', content)
1503 1503
1504 1504 zmq.select([self._query_socket], [], [])
1505 1505 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1506 1506 if self.debug:
1507 1507 pprint(msg)
1508 1508 content = msg['content']
1509 1509 if content['status'] != 'ok':
1510 1510 raise self._unwrap_exception(content)
1511 1511 mapping = content['resubmitted']
1512 1512 new_ids = [ mapping[msg_id] for msg_id in theids ]
1513 1513
1514 1514 ar = AsyncHubResult(self, msg_ids=new_ids)
1515 1515
1516 1516 if block:
1517 1517 ar.wait()
1518 1518
1519 1519 return ar
1520 1520
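    # A hedged sketch of resubmit():
    #
    #     ar = rc.resubmit(rc.history[-1])   # re-run the most recent task
    #     ar.get()                           # results come back via the Hub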
1521 1521 @spin_first
1522 1522 def result_status(self, msg_ids, status_only=True):
1523 1523 """Check on the status of the result(s) of the apply request with `msg_ids`.
1524 1524
1525 1525 If status_only is False, then the actual results will be retrieved, else
1526 1526 only the status of the results will be checked.
1527 1527
1528 1528 Parameters
1529 1529 ----------
1530 1530
1531 1531 msg_ids : list of msg_ids
1532 1532 if int:
1533 1533 Passed as index to self.history for convenience.
1534 1534 status_only : bool (default: True)
1535 1535 if False:
1536 1536 Retrieve the actual results of completed tasks.
1537 1537
1538 1538 Returns
1539 1539 -------
1540 1540
1541 1541 results : dict
1542 1542 There will always be the keys 'pending' and 'completed', which will
1543 1543 be lists of msg_ids that are incomplete or complete. If `status_only`
1544 1544 is False, then completed results will be keyed by their `msg_id`.
1545 1545 """
1546 1546 if not isinstance(msg_ids, (list,tuple)):
1547 1547 msg_ids = [msg_ids]
1548 1548
1549 1549 theids = []
1550 1550 for msg_id in msg_ids:
1551 1551 if isinstance(msg_id, int):
1552 1552 msg_id = self.history[msg_id]
1553 1553 if not isinstance(msg_id, string_types):
1554 1554 raise TypeError("msg_ids must be str, not %r"%msg_id)
1555 1555 theids.append(msg_id)
1556 1556
1557 1557 completed = []
1558 1558 local_results = {}
1559 1559
1560 1560 # comment this block out to temporarily disable local shortcut:
1561 1561 for msg_id in theids:
1562 1562 if msg_id in self.results:
1563 1563 completed.append(msg_id)
1564 1564 local_results[msg_id] = self.results[msg_id]
1565 1565 theids.remove(msg_id)
1566 1566
1567 1567 if theids: # some not locally cached
1568 1568 content = dict(msg_ids=theids, status_only=status_only)
1569 1569 msg = self.session.send(self._query_socket, "result_request", content=content)
1570 1570 zmq.select([self._query_socket], [], [])
1571 1571 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1572 1572 if self.debug:
1573 1573 pprint(msg)
1574 1574 content = msg['content']
1575 1575 if content['status'] != 'ok':
1576 1576 raise self._unwrap_exception(content)
1577 1577 buffers = msg['buffers']
1578 1578 else:
1579 1579 content = dict(completed=[],pending=[])
1580 1580
1581 1581 content['completed'].extend(completed)
1582 1582
1583 1583 if status_only:
1584 1584 return content
1585 1585
1586 1586 failures = []
1587 1587 # load cached results into result:
1588 1588 content.update(local_results)
1589 1589
1590 1590 # update cache with results:
1591 1591 for msg_id in sorted(theids):
1592 1592 if msg_id in content['completed']:
1593 1593 rec = content[msg_id]
1594 1594 parent = extract_dates(rec['header'])
1595 1595 header = extract_dates(rec['result_header'])
1596 1596 rcontent = rec['result_content']
1597 1597 iodict = rec['io']
1598 1598 if isinstance(rcontent, str):
1599 1599 rcontent = self.session.unpack(rcontent)
1600 1600
1601 1601 md = self.metadata[msg_id]
1602 1602 md_msg = dict(
1603 1603 content=rcontent,
1604 1604 parent_header=parent,
1605 1605 header=header,
1606 1606 metadata=rec['result_metadata'],
1607 1607 )
1608 1608 md.update(self._extract_metadata(md_msg))
1609 1609 if rec.get('received'):
1610 1610 md['received'] = parse_date(rec['received'])
1611 1611 md.update(iodict)
1612 1612
1613 1613 if rcontent['status'] == 'ok':
1614 1614 if header['msg_type'] == 'apply_reply':
1615 res,buffers = serialize.unserialize_object(buffers)
1615 res,buffers = serialize.deserialize_object(buffers)
1616 1616 elif header['msg_type'] == 'execute_reply':
1617 1617 res = ExecuteReply(msg_id, rcontent, md)
1618 1618 else:
1619 1619 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1620 1620 else:
1621 1621 res = self._unwrap_exception(rcontent)
1622 1622 failures.append(res)
1623 1623
1624 1624 self.results[msg_id] = res
1625 1625 content[msg_id] = res
1626 1626
1627 1627 if len(theids) == 1 and failures:
1628 1628 raise failures[0]
1629 1629
1630 1630 error.collect_exceptions(failures, "result_status")
1631 1631 return content
1632 1632
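    # A hedged sketch of result_status():
    #
    #     st = rc.result_status(rc.history, status_only=True)
    #     st['pending'], st['completed']     # msg_ids split by state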
1633 1633 @spin_first
1634 1634 def queue_status(self, targets='all', verbose=False):
1635 1635 """Fetch the status of engine queues.
1636 1636
1637 1637 Parameters
1638 1638 ----------
1639 1639
1640 1640 targets : int/str/list of ints/strs
1641 1641 the engines whose states are to be queried.
1642 1642 default : all
1643 1643 verbose : bool
1644 1644 Whether to return lengths only, or lists of ids for each element
1645 1645 """
1646 1646 if targets == 'all':
1647 1647 # allow 'all' to be evaluated on the engine
1648 1648 engine_ids = None
1649 1649 else:
1650 1650 engine_ids = self._build_targets(targets)[1]
1651 1651 content = dict(targets=engine_ids, verbose=verbose)
1652 1652 self.session.send(self._query_socket, "queue_request", content=content)
1653 1653 idents,msg = self.session.recv(self._query_socket, 0)
1654 1654 if self.debug:
1655 1655 pprint(msg)
1656 1656 content = msg['content']
1657 1657 status = content.pop('status')
1658 1658 if status != 'ok':
1659 1659 raise self._unwrap_exception(content)
1660 1660 content = rekey(content)
1661 1661 if isinstance(targets, int):
1662 1662 return content[targets]
1663 1663 else:
1664 1664 return content
1665 1665
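    # A hedged sketch of queue_status() (counts are hypothetical):
    #
    #     rc.queue_status(targets=0)
    #     # -> roughly {'queue': 0, 'completed': 4, 'tasks': 1}
    #     rc.queue_status(verbose=True)      # lists of msg_ids instead of counts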
1666 1666 def _build_msgids_from_target(self, targets=None):
1667 1667 """Build a list of msg_ids from the list of engine targets"""
1668 1668 if not targets: # needed as _build_targets otherwise uses all engines
1669 1669 return []
1670 1670 target_ids = self._build_targets(targets)[0]
1671 1671 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1672 1672
1673 1673 def _build_msgids_from_jobs(self, jobs=None):
1674 1674 """Build a list of msg_ids from "jobs" """
1675 1675 if not jobs:
1676 1676 return []
1677 1677 msg_ids = []
1678 1678 if isinstance(jobs, string_types + (AsyncResult,)):
1679 1679 jobs = [jobs]
1680 1680 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1681 1681 if bad_ids:
1682 1682 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1683 1683 for j in jobs:
1684 1684 if isinstance(j, AsyncResult):
1685 1685 msg_ids.extend(j.msg_ids)
1686 1686 else:
1687 1687 msg_ids.append(j)
1688 1688 return msg_ids
1689 1689
1690 1690 def purge_local_results(self, jobs=[], targets=[]):
1691 1691 """Clears the client caches of results and their metadata.
1692 1692
1693 1693 Individual results can be purged by msg_id, or the entire
1694 1694 history of specific targets can be purged.
1695 1695
1696 1696         Use `purge_local_results('all')` to scrub everything from the Client's
1697 1697 results and metadata caches.
1698 1698
1699 1699 After this call all `AsyncResults` are invalid and should be discarded.
1700 1700
1701 1701 If you must "reget" the results, you can still do so by using
1702 1702 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1703 1703 redownload the results from the hub if they are still available
1704 1704         (i.e. `client.purge_hub_results(...)` has not been called).
1705 1705
1706 1706 Parameters
1707 1707 ----------
1708 1708
1709 1709 jobs : str or list of str or AsyncResult objects
1710 1710 the msg_ids whose results should be purged.
1711 1711 targets : int/list of ints
1712 1712 The engines, by integer ID, whose entire result histories are to be purged.
1713 1713
1714 1714 Raises
1715 1715 ------
1716 1716
1717 1717 RuntimeError : if any of the tasks to be purged are still outstanding.
1718 1718
1719 1719 """
1720 1720 if not targets and not jobs:
1721 1721 raise ValueError("Must specify at least one of `targets` and `jobs`")
1722 1722
1723 1723 if jobs == 'all':
1724 1724 if self.outstanding:
1725 1725 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1726 1726 self.results.clear()
1727 1727 self.metadata.clear()
1728 1728 else:
1729 1729 msg_ids = set()
1730 1730 msg_ids.update(self._build_msgids_from_target(targets))
1731 1731 msg_ids.update(self._build_msgids_from_jobs(jobs))
1732 1732 still_outstanding = self.outstanding.intersection(msg_ids)
1733 1733 if still_outstanding:
1734 1734 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1735 1735 for mid in msg_ids:
1736 1736 self.results.pop(mid, None)
1737 1737 self.metadata.pop(mid, None)
1738 1738
1739 1739
1740 1740 @spin_first
1741 1741 def purge_hub_results(self, jobs=[], targets=[]):
1742 1742 """Tell the Hub to forget results.
1743 1743
1744 1744 Individual results can be purged by msg_id, or the entire
1745 1745 history of specific targets can be purged.
1746 1746
1747 1747 Use `purge_results('all')` to scrub everything from the Hub's db.
1748 1748
1749 1749 Parameters
1750 1750 ----------
1751 1751
1752 1752 jobs : str or list of str or AsyncResult objects
1753 1753 the msg_ids whose results should be forgotten.
1754 1754 targets : int/str/list of ints/strs
1755 1755 The targets, by int_id, whose entire history is to be purged.
1756 1756
1757 1757 default : None
1758 1758 """
1759 1759 if not targets and not jobs:
1760 1760 raise ValueError("Must specify at least one of `targets` and `jobs`")
1761 1761 if targets:
1762 1762 targets = self._build_targets(targets)[1]
1763 1763
1764 1764 # construct msg_ids from jobs
1765 1765 if jobs == 'all':
1766 1766 msg_ids = jobs
1767 1767 else:
1768 1768 msg_ids = self._build_msgids_from_jobs(jobs)
1769 1769
1770 1770 content = dict(engine_ids=targets, msg_ids=msg_ids)
1771 1771 self.session.send(self._query_socket, "purge_request", content=content)
1772 1772 idents, msg = self.session.recv(self._query_socket, 0)
1773 1773 if self.debug:
1774 1774 pprint(msg)
1775 1775 content = msg['content']
1776 1776 if content['status'] != 'ok':
1777 1777 raise self._unwrap_exception(content)
1778 1778
1779 1779 def purge_results(self, jobs=[], targets=[]):
1780 1780 """Clears the cached results from both the hub and the local client
1781 1781
1782 1782 Individual results can be purged by msg_id, or the entire
1783 1783 history of specific targets can be purged.
1784 1784
1785 1785 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1786 1786 the Client's db.
1787 1787
1788 1788         Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1789 1789 the same arguments.
1790 1790
1791 1791 Parameters
1792 1792 ----------
1793 1793
1794 1794 jobs : str or list of str or AsyncResult objects
1795 1795 the msg_ids whose results should be forgotten.
1796 1796 targets : int/str/list of ints/strs
1797 1797 The targets, by int_id, whose entire history is to be purged.
1798 1798
1799 1799 default : None
1800 1800 """
1801 1801 self.purge_local_results(jobs=jobs, targets=targets)
1802 1802 self.purge_hub_results(jobs=jobs, targets=targets)
1803 1803
1804 1804 def purge_everything(self):
1805 1805 """Clears all content from previous Tasks from both the hub and the local client
1806 1806
1807 1807 In addition to calling `purge_results("all")` it also deletes the history and
1808 1808 other bookkeeping lists.
1809 1809 """
1810 1810 self.purge_results("all")
1811 1811 self.history = []
1812 1812 self.session.digest_history.clear()
1813 1813
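    # A hedged sketch of the purge family:
    #
    #     rc.purge_results('all')     # scrub local *and* Hub caches
    #     rc.purge_everything()       # additionally clears history bookkeeping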
1814 1814 @spin_first
1815 1815 def hub_history(self):
1816 1816 """Get the Hub's history
1817 1817
1818 1818 Just like the Client, the Hub has a history, which is a list of msg_ids.
1819 1819 This will contain the history of all clients, and, depending on configuration,
1820 1820 may contain history across multiple cluster sessions.
1821 1821
1822 1822 Any msg_id returned here is a valid argument to `get_result`.
1823 1823
1824 1824 Returns
1825 1825 -------
1826 1826
1827 1827 msg_ids : list of strs
1828 1828 list of all msg_ids, ordered by task submission time.
1829 1829 """
1830 1830
1831 1831 self.session.send(self._query_socket, "history_request", content={})
1832 1832 idents, msg = self.session.recv(self._query_socket, 0)
1833 1833
1834 1834 if self.debug:
1835 1835 pprint(msg)
1836 1836 content = msg['content']
1837 1837 if content['status'] != 'ok':
1838 1838 raise self._unwrap_exception(content)
1839 1839 else:
1840 1840 return content['history']
1841 1841
1842 1842 @spin_first
1843 1843 def db_query(self, query, keys=None):
1844 1844 """Query the Hub's TaskRecord database
1845 1845
1846 1846 This will return a list of task record dicts that match `query`
1847 1847
1848 1848 Parameters
1849 1849 ----------
1850 1850
1851 1851 query : mongodb query dict
1852 1852 The search dict. See mongodb query docs for details.
1853 1853 keys : list of strs [optional]
1854 1854 The subset of keys to be returned. The default is to fetch everything but buffers.
1855 1855 'msg_id' will *always* be included.
1856 1856 """
1857 1857 if isinstance(keys, string_types):
1858 1858 keys = [keys]
1859 1859 content = dict(query=query, keys=keys)
1860 1860 self.session.send(self._query_socket, "db_request", content=content)
1861 1861 idents, msg = self.session.recv(self._query_socket, 0)
1862 1862 if self.debug:
1863 1863 pprint(msg)
1864 1864 content = msg['content']
1865 1865 if content['status'] != 'ok':
1866 1866 raise self._unwrap_exception(content)
1867 1867
1868 1868 records = content['records']
1869 1869
1870 1870 buffer_lens = content['buffer_lens']
1871 1871 result_buffer_lens = content['result_buffer_lens']
1872 1872 buffers = msg['buffers']
1873 1873 has_bufs = buffer_lens is not None
1874 1874 has_rbufs = result_buffer_lens is not None
1875 1875 for i,rec in enumerate(records):
1876 1876 # unpack datetime objects
1877 1877 for hkey in ('header', 'result_header'):
1878 1878 if hkey in rec:
1879 1879 rec[hkey] = extract_dates(rec[hkey])
1880 1880 for dtkey in ('submitted', 'started', 'completed', 'received'):
1881 1881 if dtkey in rec:
1882 1882 rec[dtkey] = parse_date(rec[dtkey])
1883 1883 # relink buffers
1884 1884 if has_bufs:
1885 1885 blen = buffer_lens[i]
1886 1886 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1887 1887 if has_rbufs:
1888 1888 blen = result_buffer_lens[i]
1889 1889 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1890 1890
1891 1891 return records
1892 1892
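    # A hedged sketch of db_query(), assuming datetime/timedelta are imported:
    #
    #     one_hour_ago = datetime.now() - timedelta(hours=1)
    #     recs = rc.db_query({'completed': {'$gt': one_hour_ago}},
    #                        keys=['msg_id', 'completed', 'engine_uuid'])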
1893 1893 __all__ = [ 'Client' ]
@@ -1,1438 +1,1438
1 1 """The IPython Controller Hub with 0MQ
2 2
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6
7 7 # Copyright (c) IPython Development Team.
8 8 # Distributed under the terms of the Modified BSD License.
9 9
10 10 from __future__ import print_function
11 11
12 12 import json
13 13 import os
14 14 import sys
15 15 import time
16 16 from datetime import datetime
17 17
18 18 import zmq
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 # internal:
22 22 from IPython.utils.importstring import import_item
23 23 from IPython.utils.jsonutil import extract_dates
24 24 from IPython.utils.localinterfaces import localhost
25 25 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
26 26 from IPython.utils.traitlets import (
27 27 HasTraits, Any, Instance, Integer, Unicode, Dict, Set, Tuple, DottedObjectName
28 28 )
29 29
30 30 from IPython.parallel import error, util
31 31 from IPython.parallel.factory import RegistrationFactory
32 32
33 33 from IPython.kernel.zmq.session import SessionFactory
34 34
35 35 from .heartmonitor import HeartMonitor
36 36
37 37
38 38 def _passer(*args, **kwargs):
39 39 return
40 40
41 41 def _printer(*args, **kwargs):
42 42 print (args)
43 43 print (kwargs)
44 44
45 45 def empty_record():
46 46 """Return an empty dict with all record keys."""
47 47 return {
48 48 'msg_id' : None,
49 49 'header' : None,
50 50 'metadata' : None,
51 51 'content': None,
52 52 'buffers': None,
53 53 'submitted': None,
54 54 'client_uuid' : None,
55 55 'engine_uuid' : None,
56 56 'started': None,
57 57 'completed': None,
58 58 'resubmitted': None,
59 59 'received': None,
60 60 'result_header' : None,
61 61 'result_metadata' : None,
62 62 'result_content' : None,
63 63 'result_buffers' : None,
64 64 'queue' : None,
65 65 'execute_input' : None,
66 66 'execute_result': None,
67 67 'error': None,
68 68 'stdout': '',
69 69 'stderr': '',
70 70 }
71 71
72 72 def init_record(msg):
73 73 """Initialize a TaskRecord based on a request."""
74 74 header = msg['header']
75 75 return {
76 76 'msg_id' : header['msg_id'],
77 77 'header' : header,
78 78 'content': msg['content'],
79 79 'metadata': msg['metadata'],
80 80 'buffers': msg['buffers'],
81 81 'submitted': header['date'],
82 82 'client_uuid' : None,
83 83 'engine_uuid' : None,
84 84 'started': None,
85 85 'completed': None,
86 86 'resubmitted': None,
87 87 'received': None,
88 88 'result_header' : None,
89 89 'result_metadata': None,
90 90 'result_content' : None,
91 91 'result_buffers' : None,
92 92 'queue' : None,
93 93 'execute_input' : None,
94 94 'execute_result': None,
95 95 'error': None,
96 96 'stdout': '',
97 97 'stderr': '',
98 98 }
99 99
100 100
101 101 class EngineConnector(HasTraits):
102 102 """A simple object for accessing the various zmq connections of an object.
103 103 Attributes are:
104 104 id (int): engine ID
105 105 uuid (unicode): engine UUID
106 106 pending: set of msg_ids
107 107 stallback: tornado timeout for stalled registration
108 108 """
109 109
110 110 id = Integer(0)
111 111 uuid = Unicode()
112 112 pending = Set()
113 113 stallback = Any()
114 114
115 115
116 116 _db_shortcuts = {
117 117 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
118 118 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
119 119 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
120 120 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
121 121 }
122 122
123 123 class HubFactory(RegistrationFactory):
124 124 """The Configurable for setting up a Hub."""
125 125
126 126 # port-pairs for monitoredqueues:
127 127 hb = Tuple(Integer,Integer,config=True,
128 128 help="""PUB/ROUTER Port pair for Engine heartbeats""")
129 129 def _hb_default(self):
130 130 return tuple(util.select_random_ports(2))
131 131
132 132 mux = Tuple(Integer,Integer,config=True,
133 133 help="""Client/Engine Port pair for MUX queue""")
134 134
135 135 def _mux_default(self):
136 136 return tuple(util.select_random_ports(2))
137 137
138 138 task = Tuple(Integer,Integer,config=True,
139 139 help="""Client/Engine Port pair for Task queue""")
140 140 def _task_default(self):
141 141 return tuple(util.select_random_ports(2))
142 142
143 143 control = Tuple(Integer,Integer,config=True,
144 144 help="""Client/Engine Port pair for Control queue""")
145 145
146 146 def _control_default(self):
147 147 return tuple(util.select_random_ports(2))
148 148
149 149 iopub = Tuple(Integer,Integer,config=True,
150 150 help="""Client/Engine Port pair for IOPub relay""")
151 151
152 152 def _iopub_default(self):
153 153 return tuple(util.select_random_ports(2))
154 154
155 155 # single ports:
156 156 mon_port = Integer(config=True,
157 157 help="""Monitor (SUB) port for queue traffic""")
158 158
159 159 def _mon_port_default(self):
160 160 return util.select_random_ports(1)[0]
161 161
162 162 notifier_port = Integer(config=True,
163 163 help="""PUB port for sending engine status notifications""")
164 164
165 165 def _notifier_port_default(self):
166 166 return util.select_random_ports(1)[0]
167 167
168 168 engine_ip = Unicode(config=True,
169 169 help="IP on which to listen for engine connections. [default: loopback]")
170 170 def _engine_ip_default(self):
171 171 return localhost()
172 172 engine_transport = Unicode('tcp', config=True,
173 173 help="0MQ transport for engine connections. [default: tcp]")
174 174
175 175 client_ip = Unicode(config=True,
176 176 help="IP on which to listen for client connections. [default: loopback]")
177 177 client_transport = Unicode('tcp', config=True,
178 178 help="0MQ transport for client connections. [default : tcp]")
179 179
180 180 monitor_ip = Unicode(config=True,
181 181 help="IP on which to listen for monitor messages. [default: loopback]")
182 182 monitor_transport = Unicode('tcp', config=True,
183 183 help="0MQ transport for monitor messages. [default : tcp]")
184 184
185 185 _client_ip_default = _monitor_ip_default = _engine_ip_default
186 186
187 187
188 188 monitor_url = Unicode('')
189 189
190 190 db_class = DottedObjectName('NoDB',
191 191 config=True, help="""The class to use for the DB backend
192 192
193 193 Options include:
194 194
195 195 SQLiteDB: SQLite
196 196 MongoDB : use MongoDB
197 197 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
198 198 NoDB : disable database altogether (default)
199 199
200 200 """)
201 201
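    # Editor's sketch (hedged): the shortcut names in _db_shortcuts above can
    # be used on the command line or in a config file, e.g.:
    #
    #     ipcontroller --HubFactory.db_class=sqlitedb
    #     # or, in ipcontroller_config.py:
    #     # c.HubFactory.db_class = 'mongodb'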
202 202 registration_timeout = Integer(0, config=True,
203 203 help="Engine registration timeout in seconds [default: max(30,"
204 204 "10*heartmonitor.period)]" )
205 205
206 206 def _registration_timeout_default(self):
207 207 if self.heartmonitor is None:
208 208 # early initialization, this value will be ignored
209 209 return 0
210 210         # heartmonitor period is in milliseconds; 10 periods, in seconds, is .01 * period
211 211 return max(30, int(.01 * self.heartmonitor.period))
212 212
213 213 # not configurable
214 214 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
215 215 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
216 216
217 217 def _ip_changed(self, name, old, new):
218 218 self.engine_ip = new
219 219 self.client_ip = new
220 220 self.monitor_ip = new
221 221 self._update_monitor_url()
222 222
223 223 def _update_monitor_url(self):
224 224 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
225 225
226 226 def _transport_changed(self, name, old, new):
227 227 self.engine_transport = new
228 228 self.client_transport = new
229 229 self.monitor_transport = new
230 230 self._update_monitor_url()
231 231
232 232 def __init__(self, **kwargs):
233 233 super(HubFactory, self).__init__(**kwargs)
234 234 self._update_monitor_url()
235 235
236 236
237 237 def construct(self):
238 238 self.init_hub()
239 239
240 240 def start(self):
241 241 self.heartmonitor.start()
242 242 self.log.info("Heartmonitor started")
243 243
244 244 def client_url(self, channel):
245 245 """return full zmq url for a named client channel"""
246 246 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
247 247
248 248 def engine_url(self, channel):
249 249 """return full zmq url for a named engine channel"""
250 250 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
251 251
252 252 def init_hub(self):
253 253 """construct Hub object"""
254 254
255 255 ctx = self.context
256 256 loop = self.loop
257 257 if 'TaskScheduler.scheme_name' in self.config:
258 258 scheme = self.config.TaskScheduler.scheme_name
259 259 else:
260 260 from .scheduler import TaskScheduler
261 261 scheme = TaskScheduler.scheme_name.get_default_value()
262 262
263 263 # build connection dicts
264 264 engine = self.engine_info = {
265 265 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
266 266 'registration' : self.regport,
267 267 'control' : self.control[1],
268 268 'mux' : self.mux[1],
269 269 'hb_ping' : self.hb[0],
270 270 'hb_pong' : self.hb[1],
271 271 'task' : self.task[1],
272 272 'iopub' : self.iopub[1],
273 273 }
274 274
275 275 client = self.client_info = {
276 276 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
277 277 'registration' : self.regport,
278 278 'control' : self.control[0],
279 279 'mux' : self.mux[0],
280 280 'task' : self.task[0],
281 281 'task_scheme' : scheme,
282 282 'iopub' : self.iopub[0],
283 283 'notification' : self.notifier_port,
284 284 }
285 285
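        # Editor's note (hedged): these engine/client dicts are what get
        # written to the ipcontroller-engine.json / ipcontroller-client.json
        # connection files; e.g. roughly (ports hypothetical):
        #
        #     {"interface": "tcp://127.0.0.1", "registration": 55321,
        #      "control": 55322, "mux": 55323, "task": 55324, ...}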
286 286 self.log.debug("Hub engine addrs: %s", self.engine_info)
287 287 self.log.debug("Hub client addrs: %s", self.client_info)
288 288
289 289 # Registrar socket
290 290 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
291 291 util.set_hwm(q, 0)
292 292 q.bind(self.client_url('registration'))
293 293 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
294 294 if self.client_ip != self.engine_ip:
295 295 q.bind(self.engine_url('registration'))
296 296 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
297 297
298 298 ### Engine connections ###
299 299
300 300 # heartbeat
301 301 hpub = ctx.socket(zmq.PUB)
302 302 hpub.bind(self.engine_url('hb_ping'))
303 303 hrep = ctx.socket(zmq.ROUTER)
304 304 util.set_hwm(hrep, 0)
305 305 hrep.bind(self.engine_url('hb_pong'))
306 306 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
307 307 pingstream=ZMQStream(hpub,loop),
308 308 pongstream=ZMQStream(hrep,loop)
309 309 )
310 310
311 311 ### Client connections ###
312 312
313 313 # Notifier socket
314 314 n = ZMQStream(ctx.socket(zmq.PUB), loop)
315 315 n.bind(self.client_url('notification'))
316 316
317 317 ### build and launch the queues ###
318 318
319 319 # monitor socket
320 320 sub = ctx.socket(zmq.SUB)
321 321 sub.setsockopt(zmq.SUBSCRIBE, b"")
322 322 sub.bind(self.monitor_url)
323 323 sub.bind('inproc://monitor')
324 324 sub = ZMQStream(sub, loop)
325 325
326 326 # connect the db
327 327 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
328 328 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
329 329 self.db = import_item(str(db_class))(session=self.session.session,
330 330 parent=self, log=self.log)
331 331 time.sleep(.25)
332 332
333 333 # resubmit stream
334 334 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
335 335 url = util.disambiguate_url(self.client_url('task'))
336 336 r.connect(url)
337 337
338 338 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
339 339 query=q, notifier=n, resubmit=r, db=self.db,
340 340 engine_info=self.engine_info, client_info=self.client_info,
341 341 log=self.log, registration_timeout=self.registration_timeout)
342 342
343 343
344 344 class Hub(SessionFactory):
345 345 """The IPython Controller Hub with 0MQ connections
346 346
347 347 Parameters
348 348     ----------
349 349 loop: zmq IOLoop instance
350 350 session: Session object
351 351     context (removed): zmq context for creating new connections
352 352     monitor: ZMQStream for monitoring queue traffic (SUB)
353 353     query: ZMQStream for engine registration and client query requests (ROUTER)
354 354     heartmonitor: HeartMonitor object checking the pulse of the engines
355 355 notifier: ZMQStream for broadcasting engine registration changes (PUB)
356 356 db: connection to db for out of memory logging of commands
357 357 NotImplemented
358 358 engine_info: dict of zmq connection information for engines to connect
359 359 to the queues.
360 360     client_info: dict of zmq connection information for clients to connect
361 361 to the queues.
362 362 """
363 363
364 364 engine_state_file = Unicode()
365 365
366 366 # internal data structures:
367 367 ids=Set() # engine IDs
368 368 keytable=Dict()
369 369 by_ident=Dict()
370 370 engines=Dict()
371 371 clients=Dict()
372 372 hearts=Dict()
373 373 pending=Set()
374 374 queues=Dict() # pending msg_ids keyed by engine_id
375 375 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
376 376 completed=Dict() # completed msg_ids keyed by engine_id
377 377     all_completed=Set() # set of all completed msg_ids
378 378     dead_engines=Set() # set of uuids of engines that have died
379 379     unassigned=Set() # set of task msg_ids not yet assigned a destination
380 380 incoming_registrations=Dict()
381 381 registration_timeout=Integer()
382 382 _idcounter=Integer(0)
383 383
384 384 # objects from constructor:
385 385 query=Instance(ZMQStream)
386 386 monitor=Instance(ZMQStream)
387 387 notifier=Instance(ZMQStream)
388 388 resubmit=Instance(ZMQStream)
389 389 heartmonitor=Instance(HeartMonitor)
390 390 db=Instance(object)
391 391 client_info=Dict()
392 392 engine_info=Dict()
393 393
394 394
395 395 def __init__(self, **kwargs):
396 396 """
397 397 # universal:
398 398 loop: IOLoop for creating future connections
399 399 session: streamsession for sending serialized data
400 400 # engine:
401 401         monitor: ZMQStream for monitoring queue messages
402 402         query: ZMQStream for engine+client registration and client requests
403 403         heartmonitor: HeartMonitor object for tracking engines
404 404 # extra:
405 405 db: ZMQStream for db connection (NotImplemented)
406 406 engine_info: zmq address/protocol dict for engine connections
407 407 client_info: zmq address/protocol dict for client connections
408 408 """
409 409
410 410 super(Hub, self).__init__(**kwargs)
411 411
412 412 # register our callbacks
413 413 self.query.on_recv(self.dispatch_query)
414 414 self.monitor.on_recv(self.dispatch_monitor_traffic)
415 415
416 416 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
417 417 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
418 418
419 419 self.monitor_handlers = {b'in' : self.save_queue_request,
420 420 b'out': self.save_queue_result,
421 421 b'intask': self.save_task_request,
422 422 b'outtask': self.save_task_result,
423 423 b'tracktask': self.save_task_destination,
424 424 b'incontrol': _passer,
425 425 b'outcontrol': _passer,
426 426 b'iopub': self.save_iopub_message,
427 427 }
428 428
429 429 self.query_handlers = {'queue_request': self.queue_status,
430 430 'result_request': self.get_results,
431 431 'history_request': self.get_history,
432 432 'db_request': self.db_query,
433 433 'purge_request': self.purge_results,
434 434 'load_request': self.check_load,
435 435 'resubmit_request': self.resubmit_task,
436 436 'shutdown_request': self.shutdown_request,
437 437 'registration_request' : self.register_engine,
438 438 'unregistration_request' : self.unregister_engine,
439 439 'connection_request': self.connection_request,
440 440 }
441 441
442 442 # ignore resubmit replies
443 443 self.resubmit.on_recv(lambda msg: None, copy=False)
444 444
445 445 self.log.info("hub::created hub")
446 446
447 447 @property
448 448 def _next_id(self):
449 449         """generate a new ID.
450 450
451 451 No longer reuse old ids, just count from 0."""
452 452 newid = self._idcounter
453 453 self._idcounter += 1
454 454 return newid
455 455 # newid = 0
456 456 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
457 457 # # print newid, self.ids, self.incoming_registrations
458 458 # while newid in self.ids or newid in incoming:
459 459 # newid += 1
460 460 # return newid
461 461
462 462 #-----------------------------------------------------------------------------
463 463 # message validation
464 464 #-----------------------------------------------------------------------------
465 465
466 466 def _validate_targets(self, targets):
467 467 """turn any valid targets argument into a list of integer ids"""
468 468 if targets is None:
469 469 # default to all
470 470 return self.ids
471 471
472 472 if isinstance(targets, (int,str,unicode_type)):
473 473 # only one target specified
474 474 targets = [targets]
475 475 _targets = []
476 476 for t in targets:
477 477 # map raw identities to ids
478 478 if isinstance(t, (str,unicode_type)):
479 479 t = self.by_ident.get(cast_bytes(t), t)
480 480 _targets.append(t)
481 481 targets = _targets
482 482 bad_targets = [ t for t in targets if t not in self.ids ]
483 483 if bad_targets:
484 484 raise IndexError("No Such Engine: %r" % bad_targets)
485 485 if not targets:
486 486 raise IndexError("No Engines Registered")
487 487 return targets
488 488
489 489 #-----------------------------------------------------------------------------
490 490 # dispatch methods (1 per stream)
491 491 #-----------------------------------------------------------------------------
492 492
493 493
494 494 @util.log_errors
495 495 def dispatch_monitor_traffic(self, msg):
496 496 """all ME and Task queue messages come through here, as well as
497 497 IOPub traffic."""
498 498 self.log.debug("monitor traffic: %r", msg[0])
499 499 switch = msg[0]
500 500 try:
501 501 idents, msg = self.session.feed_identities(msg[1:])
502 502 except ValueError:
503 503 idents=[]
504 504 if not idents:
505 505 self.log.error("Monitor message without topic: %r", msg)
506 506 return
507 507 handler = self.monitor_handlers.get(switch, None)
508 508 if handler is not None:
509 509 handler(idents, msg)
510 510 else:
511 511 self.log.error("Unrecognized monitor topic: %r", switch)
512 512
513 513
514 514 @util.log_errors
515 515 def dispatch_query(self, msg):
516 516 """Route registration requests and queries from clients."""
517 517 try:
518 518 idents, msg = self.session.feed_identities(msg)
519 519 except ValueError:
520 520 idents = []
521 521 if not idents:
522 522 self.log.error("Bad Query Message: %r", msg)
523 523 return
524 524 client_id = idents[0]
525 525 try:
526 msg = self.session.unserialize(msg, content=True)
526 msg = self.session.deserialize(msg, content=True)
527 527 except Exception:
528 528 content = error.wrap_exception()
529 529 self.log.error("Bad Query Message: %r", msg, exc_info=True)
530 530 self.session.send(self.query, "hub_error", ident=client_id,
531 531 content=content)
532 532 return
533 533 # print client_id, header, parent, content
534 534 #switch on message type:
535 535 msg_type = msg['header']['msg_type']
536 536 self.log.info("client::client %r requested %r", client_id, msg_type)
537 537 handler = self.query_handlers.get(msg_type, None)
538 538 try:
539 539 assert handler is not None, "Bad Message Type: %r" % msg_type
540 540 except:
541 541 content = error.wrap_exception()
542 542 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
543 543 self.session.send(self.query, "hub_error", ident=client_id,
544 544 content=content)
545 545 return
546 546
547 547 else:
548 548 handler(idents, msg)
549 549
550 550 def dispatch_db(self, msg):
551 551 """"""
552 552 raise NotImplementedError
553 553
554 554 #---------------------------------------------------------------------------
555 555 # handler methods (1 per event)
556 556 #---------------------------------------------------------------------------
557 557
558 558 #----------------------- Heartbeat --------------------------------------
559 559
560 560 def handle_new_heart(self, heart):
561 561 """handler to attach to heartbeater.
562 562 Called when a new heart starts to beat.
563 563 Triggers completion of registration."""
564 564 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
565 565 if heart not in self.incoming_registrations:
566 566 self.log.info("heartbeat::ignoring new heart: %r", heart)
567 567 else:
568 568 self.finish_registration(heart)
569 569
570 570
571 571 def handle_heart_failure(self, heart):
572 572 """handler to attach to heartbeater.
573 573 called when a previously registered heart fails to respond to beat request.
574 574 triggers unregistration"""
575 575 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
576 576         eid = self.hearts.get(heart, None)
577 577         if eid is None or self.keytable[eid] in self.dead_engines:
578 578             self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
579 579         else:
580 580             uuid = self.engines[eid].uuid
581 581             self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
582 582
583 583 #----------------------- MUX Queue Traffic ------------------------------
584 584
585 585 def save_queue_request(self, idents, msg):
586 586 if len(idents) < 2:
587 587 self.log.error("invalid identity prefix: %r", idents)
588 588 return
589 589 queue_id, client_id = idents[:2]
590 590 try:
591 msg = self.session.unserialize(msg)
591 msg = self.session.deserialize(msg)
592 592 except Exception:
593 593 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
594 594 return
595 595
596 596 eid = self.by_ident.get(queue_id, None)
597 597 if eid is None:
598 598 self.log.error("queue::target %r not registered", queue_id)
599 599 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
600 600 return
601 601 record = init_record(msg)
602 602 msg_id = record['msg_id']
603 603 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
604 604 # Unicode in records
605 605 record['engine_uuid'] = queue_id.decode('ascii')
606 606 record['client_uuid'] = msg['header']['session']
607 607 record['queue'] = 'mux'
608 608
609 609 try:
610 610             # it's possible iopub arrived first:
611 611 existing = self.db.get_record(msg_id)
612 612 for key,evalue in iteritems(existing):
613 613 rvalue = record.get(key, None)
614 614 if evalue and rvalue and evalue != rvalue:
615 615 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
616 616 elif evalue and not rvalue:
617 617 record[key] = evalue
618 618 try:
619 619 self.db.update_record(msg_id, record)
620 620 except Exception:
621 621 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
622 622 except KeyError:
623 623 try:
624 624 self.db.add_record(msg_id, record)
625 625 except Exception:
626 626 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
627 627
628 628
629 629 self.pending.add(msg_id)
630 630 self.queues[eid].append(msg_id)
631 631
632 632 def save_queue_result(self, idents, msg):
633 633 if len(idents) < 2:
634 634 self.log.error("invalid identity prefix: %r", idents)
635 635 return
636 636
637 637 client_id, queue_id = idents[:2]
638 638 try:
639 msg = self.session.unserialize(msg)
639 msg = self.session.deserialize(msg)
640 640 except Exception:
641 641 self.log.error("queue::engine %r sent invalid message to %r: %r",
642 642 queue_id, client_id, msg, exc_info=True)
643 643 return
644 644
645 645 eid = self.by_ident.get(queue_id, None)
646 646 if eid is None:
647 647 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
648 648 return
649 649
650 650 parent = msg['parent_header']
651 651 if not parent:
652 652 return
653 653 msg_id = parent['msg_id']
654 654 if msg_id in self.pending:
655 655 self.pending.remove(msg_id)
656 656 self.all_completed.add(msg_id)
657 657 self.queues[eid].remove(msg_id)
658 658 self.completed[eid].append(msg_id)
659 659 self.log.info("queue::request %r completed on %s", msg_id, eid)
660 660 elif msg_id not in self.all_completed:
661 661 # it could be a result from a dead engine that died before delivering the
662 662 # result
663 663 self.log.warn("queue:: unknown msg finished %r", msg_id)
664 664 return
665 665 # update record anyway, because the unregistration could have been premature
666 666 rheader = msg['header']
667 667 md = msg['metadata']
668 668 completed = rheader['date']
669 669 started = extract_dates(md.get('started', None))
670 670 result = {
671 671 'result_header' : rheader,
672 672 'result_metadata': md,
673 673 'result_content': msg['content'],
674 674 'received': datetime.now(),
675 675 'started' : started,
676 676 'completed' : completed
677 677 }
678 678
679 679 result['result_buffers'] = msg['buffers']
680 680 try:
681 681 self.db.update_record(msg_id, result)
682 682 except Exception:
683 683 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
684 684
685 685
686 686 #--------------------- Task Queue Traffic ------------------------------
687 687
688 688 def save_task_request(self, idents, msg):
689 689 """Save the submission of a task."""
690 690 client_id = idents[0]
691 691
692 692 try:
693 msg = self.session.unserialize(msg)
693 msg = self.session.deserialize(msg)
694 694 except Exception:
695 695 self.log.error("task::client %r sent invalid task message: %r",
696 696 client_id, msg, exc_info=True)
697 697 return
698 698 record = init_record(msg)
699 699
700 700 record['client_uuid'] = msg['header']['session']
701 701 record['queue'] = 'task'
702 702 header = msg['header']
703 703 msg_id = header['msg_id']
704 704 self.pending.add(msg_id)
705 705 self.unassigned.add(msg_id)
706 706 try:
707 707             # it's possible iopub arrived first:
708 708 existing = self.db.get_record(msg_id)
709 709 if existing['resubmitted']:
710 710 for key in ('submitted', 'client_uuid', 'buffers'):
711 711 # don't clobber these keys on resubmit
712 712 # submitted and client_uuid should be different
713 713 # and buffers might be big, and shouldn't have changed
714 714 record.pop(key)
715 715                 # still check content and header, which should not change
716 716                 # but are much cheaper to compare than buffers
717 717
718 718 for key,evalue in iteritems(existing):
719 719 if key.endswith('buffers'):
720 720 # don't compare buffers
721 721 continue
722 722 rvalue = record.get(key, None)
723 723 if evalue and rvalue and evalue != rvalue:
724 724 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
725 725 elif evalue and not rvalue:
726 726 record[key] = evalue
727 727 try:
728 728 self.db.update_record(msg_id, record)
729 729 except Exception:
730 730 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
731 731 except KeyError:
732 732 try:
733 733 self.db.add_record(msg_id, record)
734 734 except Exception:
735 735 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
736 736 except Exception:
737 737 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
738 738
739 739 def save_task_result(self, idents, msg):
740 740 """save the result of a completed task."""
741 741 client_id = idents[0]
742 742 try:
743 msg = self.session.unserialize(msg)
743 msg = self.session.deserialize(msg)
744 744 except Exception:
745 745             self.log.error("task::invalid task result message sent to %r: %r",
746 746 client_id, msg, exc_info=True)
747 747 return
748 748
749 749 parent = msg['parent_header']
750 750 if not parent:
751 751 # print msg
752 752 self.log.warn("Task %r had no parent!", msg)
753 753 return
754 754 msg_id = parent['msg_id']
755 755 if msg_id in self.unassigned:
756 756 self.unassigned.remove(msg_id)
757 757
758 758 header = msg['header']
759 759 md = msg['metadata']
760 760 engine_uuid = md.get('engine', u'')
761 761 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
762 762
763 763 status = md.get('status', None)
764 764
765 765 if msg_id in self.pending:
766 766 self.log.info("task::task %r finished on %s", msg_id, eid)
767 767 self.pending.remove(msg_id)
768 768 self.all_completed.add(msg_id)
769 769 if eid is not None:
770 770 if status != 'aborted':
771 771 self.completed[eid].append(msg_id)
772 772 if msg_id in self.tasks[eid]:
773 773 self.tasks[eid].remove(msg_id)
774 774 completed = header['date']
775 775 started = extract_dates(md.get('started', None))
776 776 result = {
777 777 'result_header' : header,
778 778 'result_metadata': msg['metadata'],
779 779 'result_content': msg['content'],
780 780 'started' : started,
781 781 'completed' : completed,
782 782 'received' : datetime.now(),
783 783 'engine_uuid': engine_uuid,
784 784 }
785 785
786 786 result['result_buffers'] = msg['buffers']
787 787 try:
788 788 self.db.update_record(msg_id, result)
789 789 except Exception:
790 790 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
791 791
792 792 else:
793 793 self.log.debug("task::unknown task %r finished", msg_id)
794 794
795 795 def save_task_destination(self, idents, msg):
796 796 try:
797 msg = self.session.unserialize(msg, content=True)
797 msg = self.session.deserialize(msg, content=True)
798 798 except Exception:
799 799 self.log.error("task::invalid task tracking message", exc_info=True)
800 800 return
801 801 content = msg['content']
802 802 # print (content)
803 803 msg_id = content['msg_id']
804 804 engine_uuid = content['engine_id']
805 805 eid = self.by_ident[cast_bytes(engine_uuid)]
806 806
807 807 self.log.info("task::task %r arrived on %r", msg_id, eid)
808 808 if msg_id in self.unassigned:
809 809 self.unassigned.remove(msg_id)
810 810 # else:
811 811 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
812 812
813 813 self.tasks[eid].append(msg_id)
814 814 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
815 815 try:
816 816 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
817 817 except Exception:
818 818 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
819 819
820 820
821 821 def mia_task_request(self, idents, msg):
822 822 raise NotImplementedError
823 823         # client_id = idents[0]
824 824 # content = dict(mia=self.mia,status='ok')
825 825 # self.session.send('mia_reply', content=content, idents=client_id)
826 826
827 827
828 828 #--------------------- IOPub Traffic ------------------------------
829 829
830 830 def save_iopub_message(self, topics, msg):
831 831 """save an iopub message into the db"""
832 832 # print (topics)
833 833 try:
834 msg = self.session.unserialize(msg, content=True)
834 msg = self.session.deserialize(msg, content=True)
835 835 except Exception:
836 836 self.log.error("iopub::invalid IOPub message", exc_info=True)
837 837 return
838 838
839 839 parent = msg['parent_header']
840 840 if not parent:
841 841 self.log.debug("iopub::IOPub message lacks parent: %r", msg)
842 842 return
843 843 msg_id = parent['msg_id']
844 844 msg_type = msg['header']['msg_type']
845 845 content = msg['content']
846 846
847 847 # ensure msg_id is in db
848 848 try:
849 849 rec = self.db.get_record(msg_id)
850 850 except KeyError:
851 851 rec = None
852 852
853 853 # stream
854 854 d = {}
855 855 if msg_type == 'stream':
856 856 name = content['name']
857 857 s = '' if rec is None else rec[name]
858 858 d[name] = s + content['text']
859 859
860 860 elif msg_type == 'error':
861 861 d['error'] = content
862 862 elif msg_type == 'execute_input':
863 863 d['execute_input'] = content['code']
864 864 elif msg_type in ('display_data', 'execute_result'):
865 865 d[msg_type] = content
866 866 elif msg_type == 'status':
867 867 pass
868 868 elif msg_type == 'data_pub':
869 869             self.log.info("ignored data_pub message for %s", msg_id)
870 870 else:
871 871 self.log.warn("unhandled iopub msg_type: %r", msg_type)
872 872
873 873 if not d:
874 874 return
875 875
876 876 if rec is None:
877 877 # new record
878 878 rec = empty_record()
879 879 rec['msg_id'] = msg_id
880 880 rec.update(d)
881 881 d = rec
882 882 update_record = self.db.add_record
883 883 else:
884 884 update_record = self.db.update_record
885 885
886 886 try:
887 887 update_record(msg_id, d)
888 888 except Exception:
889 889 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
890 890
891 891
892 892
893 893 #-------------------------------------------------------------------------
894 894 # Registration requests
895 895 #-------------------------------------------------------------------------
896 896
897 897 def connection_request(self, client_id, msg):
898 898 """Reply with connection addresses for clients."""
899 899 self.log.info("client::client %r connected", client_id)
900 900 content = dict(status='ok')
901 901 jsonable = {}
902 902 for k,v in iteritems(self.keytable):
903 903 if v not in self.dead_engines:
904 904 jsonable[str(k)] = v
905 905 content['engines'] = jsonable
906 906 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
907 907
908 908 def register_engine(self, reg, msg):
909 909 """Register a new engine."""
910 910 content = msg['content']
911 911 try:
912 912 uuid = content['uuid']
913 913 except KeyError:
914 914 self.log.error("registration::queue not specified", exc_info=True)
915 915 return
916 916
917 917 eid = self._next_id
918 918
919 919 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
920 920
921 921 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
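        # e.g. content == {'id': 3, 'status': 'ok', 'hb_period': 3000}  (values illustrative)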
922 922 # check if requesting available IDs:
923 923 if cast_bytes(uuid) in self.by_ident:
924 924 try:
925 925 raise KeyError("uuid %r in use" % uuid)
926 926 except:
927 927 content = error.wrap_exception()
928 928 self.log.error("uuid %r in use", uuid, exc_info=True)
929 929 else:
930 930 for h, ec in iteritems(self.incoming_registrations):
931 931 if uuid == h:
932 932 try:
933 933 raise KeyError("heart_id %r in use" % uuid)
934 934 except:
935 935 self.log.error("heart_id %r in use", uuid, exc_info=True)
936 936 content = error.wrap_exception()
937 937 break
938 938 elif uuid == ec.uuid:
939 939 try:
940 940 raise KeyError("uuid %r in use" % uuid)
941 941 except:
942 942 self.log.error("uuid %r in use", uuid, exc_info=True)
943 943 content = error.wrap_exception()
944 944 break
945 945
946 946 msg = self.session.send(self.query, "registration_reply",
947 947 content=content,
948 948 ident=reg)
949 949
950 950 heart = cast_bytes(uuid)
951 951
952 952 if content['status'] == 'ok':
953 953 if heart in self.heartmonitor.hearts:
954 954 # already beating
955 955 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
956 956 self.finish_registration(heart)
957 957 else:
958 958 purge = lambda : self._purge_stalled_registration(heart)
959 959 t = self.loop.add_timeout(
960 960 self.loop.time() + self.registration_timeout,
961 961 purge,
962 962 )
963 963 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=t)
964 964 else:
965 965 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
966 966
967 967 return eid
968 968
969 969 def unregister_engine(self, ident, msg):
970 970 """Unregister an engine that explicitly requested to leave."""
971 971 try:
972 972 eid = msg['content']['id']
973 973 except:
974 974 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
975 975 return
976 976 self.log.info("registration::unregister_engine(%r)", eid)
977 977
978 978 uuid = self.keytable[eid]
979 979 content=dict(id=eid, uuid=uuid)
980 980 self.dead_engines.add(uuid)
981 981
982 982 self.loop.add_timeout(
983 983 self.loop.time() + self.registration_timeout,
984 984 lambda : self._handle_stranded_msgs(eid, uuid),
985 985 )
986 986 ############## TODO: HANDLE IT ################
987 987
988 988 self._save_engine_state()
989 989
990 990 if self.notifier:
991 991 self.session.send(self.notifier, "unregistration_notification", content=content)
992 992
993 993 def _handle_stranded_msgs(self, eid, uuid):
994 994 """Handle messages known to be on an engine when the engine unregisters.
995 995
996 996 It is possible that this will fire prematurely - that is, an engine will
997 997 go down after completing a result, and the client will be notified
998 998 that the result failed and later receive the actual result.
999 999 """
1000 1000
1001 1001 outstanding = self.queues[eid]
1002 1002
1003 1003 for msg_id in outstanding:
1004 1004 self.pending.remove(msg_id)
1005 1005 self.all_completed.add(msg_id)
1006 1006 try:
1007 1007 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1008 1008 except:
1009 1009 content = error.wrap_exception()
1010 1010 # build a fake header:
1011 1011 header = {}
1012 1012 header['engine'] = uuid
1013 1013 header['date'] = datetime.now()
1014 1014 rec = dict(result_content=content, result_header=header, result_buffers=[])
1015 1015 rec['completed'] = header['date']
1016 1016 rec['engine_uuid'] = uuid
1017 1017 try:
1018 1018 self.db.update_record(msg_id, rec)
1019 1019 except Exception:
1020 1020 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1021 1021
1022 1022
1023 1023 def finish_registration(self, heart):
1024 1024 """Second half of engine registration, called after our HeartMonitor
1025 1025 has received a beat from the Engine's Heart."""
1026 1026 try:
1027 1027 ec = self.incoming_registrations.pop(heart)
1028 1028 except KeyError:
1029 1029             self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
1030 1030 return
1031 1031 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1032 1032 if ec.stallback is not None:
1033 1033 self.loop.remove_timeout(ec.stallback)
1034 1034 eid = ec.id
1035 1035 self.ids.add(eid)
1036 1036 self.keytable[eid] = ec.uuid
1037 1037 self.engines[eid] = ec
1038 1038 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1039 1039 self.queues[eid] = list()
1040 1040 self.tasks[eid] = list()
1041 1041 self.completed[eid] = list()
1042 1042 self.hearts[heart] = eid
1043 1043 content = dict(id=eid, uuid=self.engines[eid].uuid)
1044 1044 if self.notifier:
1045 1045 self.session.send(self.notifier, "registration_notification", content=content)
1046 1046 self.log.info("engine::Engine Connected: %i", eid)
1047 1047
1048 1048 self._save_engine_state()
1049 1049
1050 1050 def _purge_stalled_registration(self, heart):
1051 1051 if heart in self.incoming_registrations:
1052 1052 ec = self.incoming_registrations.pop(heart)
1053 1053 self.log.info("registration::purging stalled registration: %i", ec.id)
1054 1054 else:
1055 1055 pass
1056 1056
1057 1057 #-------------------------------------------------------------------------
1058 1058 # Engine State
1059 1059 #-------------------------------------------------------------------------
1060 1060
1061 1061
1062 1062 def _cleanup_engine_state_file(self):
1063 1063 """cleanup engine state mapping"""
1064 1064
1065 1065 if os.path.exists(self.engine_state_file):
1066 1066 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1067 1067 try:
1068 1068 os.remove(self.engine_state_file)
1069 1069 except IOError:
1070 1070 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1071 1071
1072 1072
1073 1073 def _save_engine_state(self):
1074 1074 """save engine mapping to JSON file"""
1075 1075 if not self.engine_state_file:
1076 1076 return
1077 1077         self.log.debug("save engine state to %s", self.engine_state_file)
1078 1078 state = {}
1079 1079 engines = {}
1080 1080 for eid, ec in iteritems(self.engines):
1081 1081 if ec.uuid not in self.dead_engines:
1082 1082 engines[eid] = ec.uuid
1083 1083
1084 1084 state['engines'] = engines
1085 1085
1086 1086 state['next_id'] = self._idcounter
1087 1087
1088 1088 with open(self.engine_state_file, 'w') as f:
1089 1089 json.dump(state, f)
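        # The resulting file is a small JSON document, e.g. (ids/uuids illustrative):
        #   {"engines": {"0": "6bb7c5a8-...", "1": "f01d9c3e-..."}, "next_id": 2}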
1090 1090
1091 1091
1092 1092 def _load_engine_state(self):
1093 1093 """load engine mapping from JSON file"""
1094 1094 if not os.path.exists(self.engine_state_file):
1095 1095 return
1096 1096
1097 1097         self.log.info("loading engine state from %s", self.engine_state_file)
1098 1098
1099 1099 with open(self.engine_state_file) as f:
1100 1100 state = json.load(f)
1101 1101
1102 1102 save_notifier = self.notifier
1103 1103 self.notifier = None
1104 1104 for eid, uuid in iteritems(state['engines']):
1105 1105 heart = uuid.encode('ascii')
1106 1106 # start with this heart as current and beating:
1107 1107 self.heartmonitor.responses.add(heart)
1108 1108 self.heartmonitor.hearts.add(heart)
1109 1109
1110 1110 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1111 1111 self.finish_registration(heart)
1112 1112
1113 1113 self.notifier = save_notifier
1114 1114
1115 1115 self._idcounter = state['next_id']
1116 1116
1117 1117 #-------------------------------------------------------------------------
1118 1118 # Client Requests
1119 1119 #-------------------------------------------------------------------------
1120 1120
1121 1121 def shutdown_request(self, client_id, msg):
1122 1122 """handle shutdown request."""
1123 1123 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1124 1124 # also notify other clients of shutdown
1125 1125 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1126 1126 self.loop.add_timeout(self.loop.time() + 1, self._shutdown)
1127 1127
1128 1128 def _shutdown(self):
1129 1129 self.log.info("hub::hub shutting down.")
1130 1130 time.sleep(0.1)
1131 1131 sys.exit(0)
1132 1132
1133 1133
1134 1134 def check_load(self, client_id, msg):
1135 1135 content = msg['content']
1136 1136 try:
1137 1137 targets = content['targets']
1138 1138 targets = self._validate_targets(targets)
1139 1139 except:
1140 1140 content = error.wrap_exception()
1141 1141 self.session.send(self.query, "hub_error",
1142 1142 content=content, ident=client_id)
1143 1143 return
1144 1144
1145 1145 content = dict(status='ok')
1146 1146 # loads = {}
1147 1147 for t in targets:
1148 1148             content[str(t)] = len(self.queues[t])+len(self.tasks[t])
1149 1149 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1150 1150
1151 1151
1152 1152 def queue_status(self, client_id, msg):
1153 1153 """Return the Queue status of one or more targets.
1154 1154
1155 1155 If verbose, return the msg_ids, else return len of each type.
1156 1156
1157 1157 Keys:
1158 1158
1159 1159 * queue (pending MUX jobs)
1160 1160 * tasks (pending Task jobs)
1161 1161 * completed (finished jobs from both queues)
1162 1162 """
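        # Reply content sketch for a non-verbose query (engine id illustrative):
        #   {'status': 'ok', '0': {'queue': 2, 'completed': 17, 'tasks': 1}, 'unassigned': 0}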
1163 1163 content = msg['content']
1164 1164 targets = content['targets']
1165 1165 try:
1166 1166 targets = self._validate_targets(targets)
1167 1167 except:
1168 1168 content = error.wrap_exception()
1169 1169 self.session.send(self.query, "hub_error",
1170 1170 content=content, ident=client_id)
1171 1171 return
1172 1172 verbose = content.get('verbose', False)
1173 1173 content = dict(status='ok')
1174 1174 for t in targets:
1175 1175 queue = self.queues[t]
1176 1176 completed = self.completed[t]
1177 1177 tasks = self.tasks[t]
1178 1178 if not verbose:
1179 1179 queue = len(queue)
1180 1180 completed = len(completed)
1181 1181 tasks = len(tasks)
1182 1182             content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1183 1183 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1184 1184 # print (content)
1185 1185 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1186 1186
1187 1187 def purge_results(self, client_id, msg):
1188 1188 """Purge results from memory. This method is more valuable before we move
1189 1189 to a DB based message storage mechanism."""
1190 1190 content = msg['content']
1191 1191 self.log.info("Dropping records with %s", content)
1192 1192 msg_ids = content.get('msg_ids', [])
1193 1193 reply = dict(status='ok')
1194 1194 if msg_ids == 'all':
1195 1195 try:
1196 1196 self.db.drop_matching_records(dict(completed={'$ne':None}))
1197 1197 except Exception:
1198 1198 reply = error.wrap_exception()
1199 1199 self.log.exception("Error dropping records")
1200 1200 else:
1201 1201 pending = [m for m in msg_ids if (m in self.pending)]
1202 1202 if pending:
1203 1203 try:
1204 1204 raise IndexError("msg pending: %r" % pending[0])
1205 1205 except:
1206 1206 reply = error.wrap_exception()
1207 1207 self.log.exception("Error dropping records")
1208 1208 else:
1209 1209 try:
1210 1210 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1211 1211 except Exception:
1212 1212 reply = error.wrap_exception()
1213 1213 self.log.exception("Error dropping records")
1214 1214
1215 1215 if reply['status'] == 'ok':
1216 1216 eids = content.get('engine_ids', [])
1217 1217 for eid in eids:
1218 1218 if eid not in self.engines:
1219 1219 try:
1220 1220 raise IndexError("No such engine: %i" % eid)
1221 1221 except:
1222 1222 reply = error.wrap_exception()
1223 1223 self.log.exception("Error dropping records")
1224 1224 break
1225 1225 uid = self.engines[eid].uuid
1226 1226 try:
1227 1227 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1228 1228 except Exception:
1229 1229 reply = error.wrap_exception()
1230 1230 self.log.exception("Error dropping records")
1231 1231 break
1232 1232
1233 1233 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1234 1234
1235 1235 def resubmit_task(self, client_id, msg):
1236 1236 """Resubmit one or more tasks."""
1237 1237 def finish(reply):
1238 1238 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1239 1239
1240 1240 content = msg['content']
1241 1241 msg_ids = content['msg_ids']
1242 1242 reply = dict(status='ok')
1243 1243 try:
1244 1244 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1245 1245 'header', 'content', 'buffers'])
1246 1246 except Exception:
1247 1247 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1248 1248 return finish(error.wrap_exception())
1249 1249
1250 1250 # validate msg_ids
1251 1251 found_ids = [ rec['msg_id'] for rec in records ]
1252 1252 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1253 1253 if len(records) > len(msg_ids):
1254 1254 try:
1255 1255                 raise RuntimeError("DB appears to be in an inconsistent state. "
1256 1256                     "More matching records were found than should exist")
1257 1257 except Exception:
1258 1258 self.log.exception("Failed to resubmit task")
1259 1259 return finish(error.wrap_exception())
1260 1260 elif len(records) < len(msg_ids):
1261 1261 missing = [ m for m in msg_ids if m not in found_ids ]
1262 1262 try:
1263 1263 raise KeyError("No such msg(s): %r" % missing)
1264 1264 except KeyError:
1265 1265 self.log.exception("Failed to resubmit task")
1266 1266 return finish(error.wrap_exception())
1267 1267 elif pending_ids:
1268 1268 pass
1269 1269 # no need to raise on resubmit of pending task, now that we
1270 1270 # resubmit under new ID, but do we want to raise anyway?
1271 1271 # msg_id = invalid_ids[0]
1272 1272 # try:
1273 1273 # raise ValueError("Task(s) %r appears to be inflight" % )
1274 1274 # except Exception:
1275 1275 # return finish(error.wrap_exception())
1276 1276
1277 1277 # mapping of original IDs to resubmitted IDs
1278 1278 resubmitted = {}
1279 1279
1280 1280 # send the messages
1281 1281 for rec in records:
1282 1282 header = rec['header']
1283 1283 msg = self.session.msg(header['msg_type'], parent=header)
1284 1284 msg_id = msg['msg_id']
1285 1285 msg['content'] = rec['content']
1286 1286
1287 1287 # use the old header, but update msg_id and timestamp
1288 1288 fresh = msg['header']
1289 1289 header['msg_id'] = fresh['msg_id']
1290 1290 header['date'] = fresh['date']
1291 1291 msg['header'] = header
1292 1292
1293 1293 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1294 1294
1295 1295 resubmitted[rec['msg_id']] = msg_id
1296 1296 self.pending.add(msg_id)
1297 1297 msg['buffers'] = rec['buffers']
1298 1298 try:
1299 1299 self.db.add_record(msg_id, init_record(msg))
1300 1300 except Exception:
1301 1301 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1302 1302 return finish(error.wrap_exception())
1303 1303
1304 1304 finish(dict(status='ok', resubmitted=resubmitted))
1305 1305
1306 1306 # store the new IDs in the Task DB
1307 1307 for msg_id, resubmit_id in iteritems(resubmitted):
1308 1308 try:
1309 1309 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1310 1310 except Exception:
1311 1311 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1312 1312
1313 1313
1314 1314 def _extract_record(self, rec):
1315 1315 """decompose a TaskRecord dict into subsection of reply for get_result"""
1316 1316 io_dict = {}
1317 1317 for key in ('execute_input', 'execute_result', 'error', 'stdout', 'stderr'):
1318 1318 io_dict[key] = rec[key]
1319 1319 content = {
1320 1320 'header': rec['header'],
1321 1321 'metadata': rec['metadata'],
1322 1322 'result_metadata': rec['result_metadata'],
1323 1323 'result_header' : rec['result_header'],
1324 1324 'result_content': rec['result_content'],
1325 1325 'received' : rec['received'],
1326 1326 'io' : io_dict,
1327 1327 }
1328 1328 if rec['result_buffers']:
1329 1329 buffers = list(map(bytes, rec['result_buffers']))
1330 1330 else:
1331 1331 buffers = []
1332 1332
1333 1333 return content, buffers
1334 1334
1335 1335 def get_results(self, client_id, msg):
1336 1336 """Get the result of 1 or more messages."""
1337 1337 content = msg['content']
1338 1338 msg_ids = sorted(set(content['msg_ids']))
1339 1339 statusonly = content.get('status_only', False)
1340 1340 pending = []
1341 1341 completed = []
1342 1342 content = dict(status='ok')
1343 1343 content['pending'] = pending
1344 1344 content['completed'] = completed
1345 1345 buffers = []
1346 1346 if not statusonly:
1347 1347 try:
1348 1348 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1349 1349 # turn match list into dict, for faster lookup
1350 1350 records = {}
1351 1351 for rec in matches:
1352 1352 records[rec['msg_id']] = rec
1353 1353 except Exception:
1354 1354 content = error.wrap_exception()
1355 1355 self.log.exception("Failed to get results")
1356 1356 self.session.send(self.query, "result_reply", content=content,
1357 1357 parent=msg, ident=client_id)
1358 1358 return
1359 1359 else:
1360 1360 records = {}
1361 1361 for msg_id in msg_ids:
1362 1362 if msg_id in self.pending:
1363 1363 pending.append(msg_id)
1364 1364 elif msg_id in self.all_completed:
1365 1365 completed.append(msg_id)
1366 1366 if not statusonly:
1367 1367 c,bufs = self._extract_record(records[msg_id])
1368 1368 content[msg_id] = c
1369 1369 buffers.extend(bufs)
1370 1370 elif msg_id in records:
1371 1371                 if records[msg_id]['completed']:
1372 1372 completed.append(msg_id)
1373 1373 c,bufs = self._extract_record(records[msg_id])
1374 1374 content[msg_id] = c
1375 1375 buffers.extend(bufs)
1376 1376 else:
1377 1377 pending.append(msg_id)
1378 1378 else:
1379 1379 try:
1380 1380 raise KeyError('No such message: '+msg_id)
1381 1381 except:
1382 1382 content = error.wrap_exception()
1383 1383 break
1384 1384 self.session.send(self.query, "result_reply", content=content,
1385 1385 parent=msg, ident=client_id,
1386 1386 buffers=buffers)
1387 1387
1388 1388 def get_history(self, client_id, msg):
1389 1389 """Get a list of all msg_ids in our DB records"""
1390 1390 try:
1391 1391 msg_ids = self.db.get_history()
1392 1392 except Exception as e:
1393 1393 content = error.wrap_exception()
1394 1394 self.log.exception("Failed to get history")
1395 1395 else:
1396 1396 content = dict(status='ok', history=msg_ids)
1397 1397
1398 1398 self.session.send(self.query, "history_reply", content=content,
1399 1399 parent=msg, ident=client_id)
1400 1400
1401 1401 def db_query(self, client_id, msg):
1402 1402 """Perform a raw query on the task record database."""
1403 1403 content = msg['content']
1404 1404 query = extract_dates(content.get('query', {}))
1405 1405 keys = content.get('keys', None)
1406 1406 buffers = []
1407 1407 empty = list()
1408 1408 try:
1409 1409 records = self.db.find_records(query, keys)
1410 1410 except Exception as e:
1411 1411 content = error.wrap_exception()
1412 1412 self.log.exception("DB query failed")
1413 1413 else:
1414 1414 # extract buffers from reply content:
1415 1415 if keys is not None:
1416 1416 buffer_lens = [] if 'buffers' in keys else None
1417 1417 result_buffer_lens = [] if 'result_buffers' in keys else None
1418 1418 else:
1419 1419 buffer_lens = None
1420 1420 result_buffer_lens = None
1421 1421
1422 1422 for rec in records:
1423 1423 # buffers may be None, so double check
1424 1424 b = rec.pop('buffers', empty) or empty
1425 1425 if buffer_lens is not None:
1426 1426 buffer_lens.append(len(b))
1427 1427 buffers.extend(b)
1428 1428 rb = rec.pop('result_buffers', empty) or empty
1429 1429 if result_buffer_lens is not None:
1430 1430 result_buffer_lens.append(len(rb))
1431 1431 buffers.extend(rb)
1432 1432 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1433 1433 result_buffer_lens=result_buffer_lens)
1434 1434 # self.log.debug (content)
1435 1435 self.session.send(self.query, "db_reply", content=content,
1436 1436 parent=msg, ident=client_id,
1437 1437 buffers=buffers)
1438 1438
@@ -1,849 +1,849
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 # Copyright (c) IPython Development Team.
9 9 # Distributed under the terms of the Modified BSD License.
10 10
11 11 import logging
12 12 import sys
13 13 import time
14 14
15 15 from collections import deque
16 16 from datetime import datetime
17 17 from random import randint, random
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import numpy
22 22 except ImportError:
23 23 numpy = None
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop, zmqstream
27 27
28 28 # local imports
29 29 from IPython.external.decorator import decorator
30 30 from IPython.config.application import Application
31 31 from IPython.config.loader import Config
32 32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 33 from IPython.utils.py3compat import cast_bytes
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import SessionFactory
37 37 from IPython.parallel.util import connect_logger, local_logger
38 38
39 39 from .dependency import Dependency
40 40
41 41 @decorator
42 42 def logged(f,self,*args,**kwargs):
43 43 # print ("#--------------------")
44 44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
45 45 # print ("#--")
46 46 return f(self,*args, **kwargs)
47 47
48 48 #----------------------------------------------------------------------
49 49 # Chooser functions
50 50 #----------------------------------------------------------------------
51 51
52 52 def plainrandom(loads):
53 53 """Plain random pick."""
54 54 n = len(loads)
55 55 return randint(0,n-1)
56 56
57 57 def lru(loads):
58 58 """Always pick the front of the line.
59 59
60 60 The content of `loads` is ignored.
61 61
62 62 Assumes LRU ordering of loads, with oldest first.
63 63 """
64 64 return 0
65 65
66 66 def twobin(loads):
67 67 """Pick two at random, use the LRU of the two.
68 68
69 69 The content of loads is ignored.
70 70
71 71 Assumes LRU ordering of loads, with oldest first.
72 72 """
73 73 n = len(loads)
74 74 a = randint(0,n-1)
75 75 b = randint(0,n-1)
76 76 return min(a,b)
77 77
78 78 def weighted(loads):
79 79 """Pick two at random using inverse load as weight.
80 80
81 81 Return the less loaded of the two.
82 82 """
83 83 # weight 0 a million times more than 1:
84 84 weights = 1./(1e-6+numpy.array(loads))
85 85 sums = weights.cumsum()
86 86 t = sums[-1]
87 87 x = random()*t
88 88 y = random()*t
89 89 idx = 0
90 90 idy = 0
91 91 while sums[idx] < x:
92 92 idx += 1
93 93 while sums[idy] < y:
94 94 idy += 1
95 95 if weights[idy] > weights[idx]:
96 96 return idy
97 97 else:
98 98 return idx
99 99
100 100 def leastload(loads):
101 101 """Always choose the lowest load.
102 102
103 103 If the lowest load occurs more than once, the first
104 104     occurrence will be used. If loads has LRU ordering, this means
105 105 the LRU of those with the lowest load is chosen.
106 106 """
107 107 return loads.index(min(loads))
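# Worked example for the chooser functions above, with illustrative loads:
#   loads = [2, 0, 1]
#   lru(loads)        -> 0   # always the front (oldest) of the line
#   leastload(loads)  -> 1   # index of the minimum load
# twobin draws two random indices and keeps the LRU (lower) one; weighted
# draws two indices with probability ~ 1/load and keeps the less loaded,
# approximating leastload without a full scan.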
108 108
109 109 #---------------------------------------------------------------------
110 110 # Classes
111 111 #---------------------------------------------------------------------
112 112
113 113
114 114 # store empty default dependency:
115 115 MET = Dependency([])
116 116
117 117
118 118 class Job(object):
119 119 """Simple container for a job"""
120 120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
121 121 targets, after, follow, timeout):
122 122 self.msg_id = msg_id
123 123 self.raw_msg = raw_msg
124 124 self.idents = idents
125 125 self.msg = msg
126 126 self.header = header
127 127 self.metadata = metadata
128 128 self.targets = targets
129 129 self.after = after
130 130 self.follow = follow
131 131 self.timeout = timeout
132 132
133 133 self.removed = False # used for lazy-delete from sorted queue
134 134 self.timestamp = time.time()
135 135 self.timeout_id = 0
136 136 self.blacklist = set()
137 137
138 138 def __lt__(self, other):
139 139 return self.timestamp < other.timestamp
140 140
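    # __cmp__ is consulted only under Python 2; Python 3 ordering uses __lt__ above.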
141 141 def __cmp__(self, other):
142 142 return cmp(self.timestamp, other.timestamp)
143 143
144 144 @property
145 145 def dependents(self):
146 146 return self.follow.union(self.after)
147 147
148 148
149 149 class TaskScheduler(SessionFactory):
150 150 """Python TaskScheduler object.
151 151
152 152 This is the simplest object that supports msg_id based
153 153 DAG dependencies. *Only* task msg_ids are checked, not
154 154 msg_ids of jobs submitted via the MUX queue.
155 155
156 156 """
157 157
158 158 hwm = Integer(1, config=True,
159 159 help="""specify the High Water Mark (HWM) for the downstream
160 160 socket in the Task scheduler. This is the maximum number
161 161 of allowed outstanding tasks on each engine.
162 162
163 163 The default (1) means that only one task can be outstanding on each
164 164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 165 engines continue to be assigned tasks while they are working,
166 166 effectively hiding network latency behind computation, but can result
167 167         in an imbalance of work when submitting many heterogeneous tasks all at
168 168 once. Any positive value greater than one is a compromise between the
169 169 two.
170 170
171 171 """
172 172 )
173 173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 174 'leastload', config=True, allow_none=False,
175 175 help="""select the task scheduler scheme [default: Python LRU]
176 176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 177 )
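    # A minimal config sketch using the standard IPython config syntax
    # (values illustrative):
    #   c.TaskScheduler.hwm = 0                  # no per-engine limit
    #   c.TaskScheduler.scheme_name = 'weighted'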
178 178 def _scheme_name_changed(self, old, new):
179 179         self.log.debug("Using scheme %r", new)
180 180 self.scheme = globals()[new]
181 181
182 182 # input arguments:
183 183 scheme = Instance(FunctionType) # function for determining the destination
184 184 def _scheme_default(self):
185 185 return leastload
186 186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191 191
192 192 # internals:
193 193 queue = Instance(deque) # sorted list of Jobs
194 194 def _queue_default(self):
195 195 return deque()
196 196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 200 pending = Dict() # dict by engine_uuid of submitted tasks
201 201 completed = Dict() # dict by engine_uuid of completed tasks
202 202 failed = Dict() # dict by engine_uuid of failed tasks
203 203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 204 clients = Dict() # dict by msg_id for who submitted the task
205 205 targets = List() # list of target IDENTs
206 206 loads = List() # list of engine loads
207 207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 208 all_completed = Set() # set of all completed tasks
209 209 all_failed = Set() # set of all failed tasks
210 210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 211 all_ids = Set() # set of all submitted task IDs
212 212
213 213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 214 # but ensure Bytes
215 215 def _ident_default(self):
216 216 return self.session.bsession
217 217
218 218 def start(self):
219 219 self.query_stream.on_recv(self.dispatch_query_reply)
220 220 self.session.send(self.query_stream, "connection_request", {})
221 221
222 222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
223 223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
224 224
225 225 self._notification_handlers = dict(
226 226 registration_notification = self._register_engine,
227 227 unregistration_notification = self._unregister_engine
228 228 )
229 229 self.notifier_stream.on_recv(self.dispatch_notification)
230 230         self.log.info("Scheduler started [%s]", self.scheme_name)
231 231
232 232 def resume_receiving(self):
233 233 """Resume accepting jobs."""
234 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 235
236 236 def stop_receiving(self):
237 237 """Stop accepting jobs while there are no engines.
238 238 Leave them in the ZMQ queue."""
239 239 self.client_stream.on_recv(None)
240 240
241 241 #-----------------------------------------------------------------------
242 242 # [Un]Registration Handling
243 243 #-----------------------------------------------------------------------
244 244
245 245
246 246 def dispatch_query_reply(self, msg):
247 247 """handle reply to our initial connection request"""
248 248 try:
249 249 idents,msg = self.session.feed_identities(msg)
250 250 except ValueError:
251 251             self.log.warn("task::Invalid Message: %r", msg)
252 252 return
253 253 try:
254 msg = self.session.unserialize(msg)
254 msg = self.session.deserialize(msg)
255 255 except ValueError:
256 256             self.log.warn("task::Unauthorized message from: %r", idents)
257 257 return
258 258
259 259 content = msg['content']
260 260 for uuid in content.get('engines', {}).values():
261 261 self._register_engine(cast_bytes(uuid))
262 262
263 263
264 264 @util.log_errors
265 265 def dispatch_notification(self, msg):
266 266 """dispatch register/unregister events."""
267 267 try:
268 268 idents,msg = self.session.feed_identities(msg)
269 269 except ValueError:
270 270             self.log.warn("task::Invalid Message: %r", msg)
271 271 return
272 272 try:
273 msg = self.session.unserialize(msg)
273 msg = self.session.deserialize(msg)
274 274 except ValueError:
275 275             self.log.warn("task::Unauthorized message from: %r", idents)
276 276 return
277 277
278 278 msg_type = msg['header']['msg_type']
279 279
280 280 handler = self._notification_handlers.get(msg_type, None)
281 281 if handler is None:
282 282             self.log.error("Unhandled message type: %r", msg_type)
283 283 else:
284 284 try:
285 285 handler(cast_bytes(msg['content']['uuid']))
286 286 except Exception:
287 287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288 288
289 289 def _register_engine(self, uid):
290 290 """New engine with ident `uid` became available."""
291 291 # head of the line:
292 292 self.targets.insert(0,uid)
293 293 self.loads.insert(0,0)
294 294
295 295 # initialize sets
296 296 self.completed[uid] = set()
297 297 self.failed[uid] = set()
298 298 self.pending[uid] = {}
299 299
300 300 # rescan the graph:
301 301 self.update_graph(None)
302 302
303 303 def _unregister_engine(self, uid):
304 304 """Existing engine with ident `uid` became unavailable."""
305 305 if len(self.targets) == 1:
306 306 # this was our only engine
307 307 pass
308 308
309 309 # handle any potentially finished tasks:
310 310 self.engine_stream.flush()
311 311
312 312 # don't pop destinations, because they might be used later
313 313 # map(self.destinations.pop, self.completed.pop(uid))
314 314 # map(self.destinations.pop, self.failed.pop(uid))
315 315
316 316 # prevent this engine from receiving work
317 317 idx = self.targets.index(uid)
318 318 self.targets.pop(idx)
319 319 self.loads.pop(idx)
320 320
321 321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 322 # still be incoming
323 323 if self.pending[uid]:
324 324 self.loop.add_timeout(self.loop.time() + 5,
325 325 lambda : self.handle_stranded_tasks(uid),
326 326 )
327 327 else:
328 328 self.completed.pop(uid)
329 329 self.failed.pop(uid)
330 330
331 331
332 332 def handle_stranded_tasks(self, engine):
333 333 """Deal with jobs resident in an engine that died."""
334 334 lost = self.pending[engine]
335 335         for msg_id in list(lost.keys()):
336 336 if msg_id not in self.pending[engine]:
337 337 # prevent double-handling of messages
338 338 continue
339 339
340 340 raw_msg = lost[msg_id].raw_msg
341 341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
342 342 parent = self.session.unpack(msg[1].bytes)
343 343 idents = [engine, idents[0]]
344 344
345 345 # build fake error reply
346 346 try:
347 347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
348 348 except:
349 349 content = error.wrap_exception()
350 350 # build fake metadata
351 351 md = dict(
352 352 status=u'error',
353 353 engine=engine.decode('ascii'),
354 354 date=datetime.now(),
355 355 )
356 356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
357 357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
358 358 # and dispatch it
359 359 self.dispatch_result(raw_reply)
360 360
361 361 # finally scrub completed/failed lists
362 362 self.completed.pop(engine)
363 363 self.failed.pop(engine)
364 364
365 365
366 366 #-----------------------------------------------------------------------
367 367 # Job Submission
368 368 #-----------------------------------------------------------------------
369 369
370 370
371 371 @util.log_errors
372 372 def dispatch_submission(self, raw_msg):
373 373 """Dispatch job submission to appropriate handlers."""
374 374 # ensure targets up to date:
375 375 self.notifier_stream.flush()
376 376 try:
377 377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
378 msg = self.session.unserialize(msg, content=False, copy=False)
378 msg = self.session.deserialize(msg, content=False, copy=False)
379 379 except Exception:
380 380             self.log.error("task::Invalid task msg: %r", raw_msg, exc_info=True)
381 381 return
382 382
383 383
384 384 # send to monitor
385 385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
386 386
387 387 header = msg['header']
388 388 md = msg['metadata']
389 389 msg_id = header['msg_id']
390 390 self.all_ids.add(msg_id)
391 391
392 392 # get targets as a set of bytes objects
393 393 # from a list of unicode objects
394 394 targets = md.get('targets', [])
395 395 targets = set(map(cast_bytes, targets))
396 396
397 397 retries = md.get('retries', 0)
398 398 self.retries[msg_id] = retries
399 399
400 400 # time dependencies
401 401 after = md.get('after', None)
402 402 if after:
403 403 after = Dependency(after)
404 404 if after.all:
405 405 if after.success:
406 406 after = Dependency(after.difference(self.all_completed),
407 407 success=after.success,
408 408 failure=after.failure,
409 409 all=after.all,
410 410 )
411 411 if after.failure:
412 412 after = Dependency(after.difference(self.all_failed),
413 413 success=after.success,
414 414 failure=after.failure,
415 415 all=after.all,
416 416 )
417 417 if after.check(self.all_completed, self.all_failed):
418 418 # recast as empty set, if `after` already met,
419 419 # to prevent unnecessary set comparisons
420 420 after = MET
421 421 else:
422 422 after = MET
423 423
424 424 # location dependencies
425 425 follow = Dependency(md.get('follow', []))
426 426
427 427 timeout = md.get('timeout', None)
428 428 if timeout:
429 429 timeout = float(timeout)
430 430
431 431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
432 432 header=header, targets=targets, after=after, follow=follow,
433 433 timeout=timeout, metadata=md,
434 434 )
435 435 # validate and reduce dependencies:
436 436 for dep in after,follow:
437 437 if not dep: # empty dependency
438 438 continue
439 439 # check valid:
440 440 if msg_id in dep or dep.difference(self.all_ids):
441 441 self.queue_map[msg_id] = job
442 442 return self.fail_unreachable(msg_id, error.InvalidDependency)
443 443 # check if unreachable:
444 444 if dep.unreachable(self.all_completed, self.all_failed):
445 445 self.queue_map[msg_id] = job
446 446 return self.fail_unreachable(msg_id)
447 447
448 448 if after.check(self.all_completed, self.all_failed):
449 449 # time deps already met, try to run
450 450 if not self.maybe_run(job):
451 451 # can't run yet
452 452 if msg_id not in self.all_failed:
453 453 # could have failed as unreachable
454 454 self.save_unmet(job)
455 455 else:
456 456 self.save_unmet(job)
457 457
458 458 def job_timeout(self, job, timeout_id):
459 459 """callback for a job's timeout.
460 460
461 461 The job may or may not have been run at this point.
462 462 """
463 463 if job.timeout_id != timeout_id:
464 464 # not the most recent call
465 465 return
466 466 now = time.time()
467 467 if job.timeout >= (now + 1):
468 468 self.log.warn("task %s timeout fired prematurely: %s > %s",
469 469 job.msg_id, job.timeout, now
470 470 )
471 471 if job.msg_id in self.queue_map:
472 472 # still waiting, but ran out of time
473 473 self.log.info("task %r timed out", job.msg_id)
474 474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
475 475
476 476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
477 477 """a task has become unreachable, send a reply with an ImpossibleDependency
478 478 error."""
479 479 if msg_id not in self.queue_map:
480 480 self.log.error("task %r already failed!", msg_id)
481 481 return
482 482 job = self.queue_map.pop(msg_id)
483 483 # lazy-delete from the queue
484 484 job.removed = True
485 485 for mid in job.dependents:
486 486 if mid in self.graph:
487 487 self.graph[mid].remove(msg_id)
488 488
489 489 try:
490 490 raise why()
491 491 except:
492 492 content = error.wrap_exception()
493 493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
494 494
495 495 self.all_done.add(msg_id)
496 496 self.all_failed.add(msg_id)
497 497
498 498 msg = self.session.send(self.client_stream, 'apply_reply', content,
499 499 parent=job.header, ident=job.idents)
500 500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
501 501
502 502 self.update_graph(msg_id, success=False)
503 503
504 504 def available_engines(self):
505 505 """return a list of available engine indices based on HWM"""
506 506 if not self.hwm:
507 507 return list(range(len(self.targets)))
508 508 available = []
509 509 for idx in range(len(self.targets)):
510 510 if self.loads[idx] < self.hwm:
511 511 available.append(idx)
512 512 return available
513 513
514 514 def maybe_run(self, job):
515 515 """check location dependencies, and run if they are met."""
516 516 msg_id = job.msg_id
517 517 self.log.debug("Attempting to assign task %s", msg_id)
518 518 available = self.available_engines()
519 519 if not available:
520 520 # no engines, definitely can't run
521 521 return False
522 522
523 523 if job.follow or job.targets or job.blacklist or self.hwm:
524 524 # we need a can_run filter
525 525 def can_run(idx):
526 526 # check hwm
527 527 if self.hwm and self.loads[idx] == self.hwm:
528 528 return False
529 529 target = self.targets[idx]
530 530 # check blacklist
531 531 if target in job.blacklist:
532 532 return False
533 533 # check targets
534 534 if job.targets and target not in job.targets:
535 535 return False
536 536 # check follow
537 537 return job.follow.check(self.completed[target], self.failed[target])
538 538
539 539 indices = list(filter(can_run, available))
540 540
541 541 if not indices:
542 542 # couldn't run
543 543 if job.follow.all:
544 544 # check follow for impossibility
545 545 dests = set()
546 546 relevant = set()
547 547 if job.follow.success:
548 548 relevant = self.all_completed
549 549 if job.follow.failure:
550 550 relevant = relevant.union(self.all_failed)
551 551 for m in job.follow.intersection(relevant):
552 552 dests.add(self.destinations[m])
553 553 if len(dests) > 1:
554 554 self.queue_map[msg_id] = job
555 555 self.fail_unreachable(msg_id)
556 556 return False
557 557 if job.targets:
558 558 # check blacklist+targets for impossibility
559 559 job.targets.difference_update(job.blacklist)
560 560 if not job.targets or not job.targets.intersection(self.targets):
561 561 self.queue_map[msg_id] = job
562 562 self.fail_unreachable(msg_id)
563 563 return False
564 564 return False
565 565 else:
566 566 indices = None
567 567
568 568 self.submit_task(job, indices)
569 569 return True
570 570
571 571 def save_unmet(self, job):
572 572 """Save a message for later submission when its dependencies are met."""
573 573 msg_id = job.msg_id
574 574 self.log.debug("Adding task %s to the queue", msg_id)
575 575 self.queue_map[msg_id] = job
576 576 self.queue.append(job)
577 577 # track the ids in follow or after, but not those already finished
578 578 for dep_id in job.after.union(job.follow).difference(self.all_done):
579 579 if dep_id not in self.graph:
580 580 self.graph[dep_id] = set()
581 581 self.graph[dep_id].add(msg_id)
582 582
583 583 # schedule timeout callback
584 584 if job.timeout:
585 585 timeout_id = job.timeout_id = job.timeout_id + 1
586 586             self.loop.add_timeout(self.loop.time() + job.timeout,
587 587 lambda : self.job_timeout(job, timeout_id)
588 588 )
589 589
590 590
591 591 def submit_task(self, job, indices=None):
592 592 """Submit a task to any of a subset of our targets."""
593 593 if indices:
594 594 loads = [self.loads[i] for i in indices]
595 595 else:
596 596 loads = self.loads
597 597 idx = self.scheme(loads)
598 598 if indices:
599 599 idx = indices[idx]
600 600 target = self.targets[idx]
601 601 # print (target, map(str, msg[:3]))
602 602 # send job to the engine
603 603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 605 # update load
606 606 self.add_job(idx)
607 607 self.pending[target][job.msg_id] = job
608 608 # notify Hub
609 609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 610 self.session.send(self.mon_stream, 'task_destination', content=content,
611 611 ident=[b'tracktask',self.ident])
612 612
613 613
614 614 #-----------------------------------------------------------------------
615 615 # Result Handling
616 616 #-----------------------------------------------------------------------
617 617
618 618
619 619 @util.log_errors
620 620 def dispatch_result(self, raw_msg):
621 621 """dispatch method for result replies"""
622 622 try:
623 623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 msg = self.session.unserialize(msg, content=False, copy=False)
624 msg = self.session.deserialize(msg, content=False, copy=False)
625 625 engine = idents[0]
626 626 try:
627 627 idx = self.targets.index(engine)
628 628 except ValueError:
629 629 pass # skip load-update for dead engines
630 630 else:
631 631 self.finish_job(idx)
632 632 except Exception:
633 633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 634 return
635 635
636 636 md = msg['metadata']
637 637 parent = msg['parent_header']
638 638 if md.get('dependencies_met', True):
639 639 success = (md['status'] == 'ok')
640 640 msg_id = parent['msg_id']
641 641 retries = self.retries[msg_id]
642 642 if not success and retries > 0:
643 643 # failed
644 644 self.retries[msg_id] = retries - 1
645 645 self.handle_unmet_dependency(idents, parent)
646 646 else:
647 647 del self.retries[msg_id]
648 648 # relay to client and update graph
649 649 self.handle_result(idents, parent, raw_msg, success)
650 650 # send to Hub monitor
651 651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 652 else:
653 653 self.handle_unmet_dependency(idents, parent)
654 654
655 655 def handle_result(self, idents, parent, raw_msg, success=True):
656 656 """handle a real task result, either success or failure"""
657 657 # first, relay result to client
658 658 engine = idents[0]
659 659 client = idents[1]
660 660 # swap_ids for ROUTER-ROUTER mirror
661 661 raw_msg[:2] = [client,engine]
662 662 # print (map(str, raw_msg[:4]))
663 663 self.client_stream.send_multipart(raw_msg, copy=False)
664 664 # now, update our data structures
665 665 msg_id = parent['msg_id']
666 666 self.pending[engine].pop(msg_id)
667 667 if success:
668 668 self.completed[engine].add(msg_id)
669 669 self.all_completed.add(msg_id)
670 670 else:
671 671 self.failed[engine].add(msg_id)
672 672 self.all_failed.add(msg_id)
673 673 self.all_done.add(msg_id)
674 674 self.destinations[msg_id] = engine
675 675
676 676 self.update_graph(msg_id, success)
677 677
678 678 def handle_unmet_dependency(self, idents, parent):
679 679 """handle an unmet dependency"""
680 680 engine = idents[0]
681 681 msg_id = parent['msg_id']
682 682
683 683 job = self.pending[engine].pop(msg_id)
684 684 job.blacklist.add(engine)
685 685
686 686 if job.blacklist == job.targets:
687 687 self.queue_map[msg_id] = job
688 688 self.fail_unreachable(msg_id)
689 689 elif not self.maybe_run(job):
690 690             # resubmission failed
691 691 if msg_id not in self.all_failed:
692 692 # put it back in our dependency tree
693 693 self.save_unmet(job)
694 694
695 695 if self.hwm:
696 696 try:
697 697 idx = self.targets.index(engine)
698 698 except ValueError:
699 699 pass # skip load-update for dead engines
700 700 else:
701 701 if self.loads[idx] == self.hwm-1:
702 702 self.update_graph(None)
703 703
704 704 def update_graph(self, dep_id=None, success=True):
705 705 """dep_id just finished. Update our dependency
706 706 graph and submit any jobs that just became runnable.
707 707
708 708         Called with dep_id=None to update the entire graph for hwm, without marking any task as finished.
709 709 """
710 710 # print ("\n\n***********")
711 711 # pprint (dep_id)
712 712 # pprint (self.graph)
713 713 # pprint (self.queue_map)
714 714 # pprint (self.all_completed)
715 715 # pprint (self.all_failed)
716 716 # print ("\n\n***********\n\n")
717 717 # update any jobs that depended on the dependency
718 718 msg_ids = self.graph.pop(dep_id, [])
719 719
720 720 # recheck *all* jobs if
721 721         # a) we have HWM and an engine just became no longer full
722 722 # or b) dep_id was given as None
723 723
724 724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
725 725 jobs = self.queue
726 726 using_queue = True
727 727 else:
728 728 using_queue = False
729 729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730 730
731 731 to_restore = []
732 732 while jobs:
733 733 job = jobs.popleft()
734 734 if job.removed:
735 735 continue
736 736 msg_id = job.msg_id
737 737
738 738 put_it_back = True
739 739
740 740 if job.after.unreachable(self.all_completed, self.all_failed)\
741 741 or job.follow.unreachable(self.all_completed, self.all_failed):
742 742 self.fail_unreachable(msg_id)
743 743 put_it_back = False
744 744
745 745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 746 if self.maybe_run(job):
747 747 put_it_back = False
748 748 self.queue_map.pop(msg_id)
749 749 for mid in job.dependents:
750 750 if mid in self.graph:
751 751 self.graph[mid].remove(msg_id)
752 752
753 753 # abort the loop if we just filled up all of our engines.
754 754                     # avoids an O(N) operation in the situation of a full queue,
755 755 # where graph update is triggered as soon as an engine becomes
756 756 # non-full, and all tasks after the first are checked,
757 757 # even though they can't run.
758 758 if not self.available_engines():
759 759 break
760 760
761 761 if using_queue and put_it_back:
762 762 # popped a job from the queue but it neither ran nor failed,
763 763 # so we need to put it back when we are done
764 764 # make sure to_restore preserves the same ordering
765 765 to_restore.append(job)
766 766
767 767 # put back any tasks we popped but didn't run
768 768 if using_queue:
769 769 self.queue.extendleft(to_restore)
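            # Note that deque.extendleft prepends items one at a time, so the
            # restored jobs land at the head of the queue in reverse iteration
            # order of to_restore, e.g.:
            #
            #   from collections import deque
            #   q = deque(['d'])
            #   q.extendleft(['a', 'b', 'c'])  # -> deque(['c', 'b', 'a', 'd'])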
770 770
771 771 #----------------------------------------------------------------------
772 772 # methods to be overridden by subclasses
773 773 #----------------------------------------------------------------------
774 774
775 775 def add_job(self, idx):
776 776         """Called after self.targets[idx] has just been assigned a job.
777 777         Override in subclasses. The default ordering is simple LRU.
778 778 The default loads are the number of outstanding jobs."""
779 779 self.loads[idx] += 1
780 780 for lis in (self.targets, self.loads):
781 781 lis.append(lis.pop(idx))
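        # e.g. with targets [a, b, c] and idx=1, both lists rotate so that
        # engine b moves to the back: [a, c, b].  On equal loads, a scheme
        # that picks the first minimum therefore prefers the engine that has
        # gone longest without an assignment.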
782 782
783 783
784 784 def finish_job(self, idx):
785 785 """Called after self.targets[idx] just finished a job.
786 786         Override in subclasses."""
787 787 self.loads[idx] -= 1
788 788
789 789
790 790
791 791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 792 logname='root', log_url=None, loglevel=logging.DEBUG,
793 793 identity=b'task', in_thread=False):
794 794
795 795 ZMQStream = zmqstream.ZMQStream
796 796
797 797 if config:
798 798 # unwrap dict back into Config
799 799 config = Config(config)
800 800
801 801 if in_thread:
802 802 # use instance() to get the same Context/Loop as our parent
803 803 ctx = zmq.Context.instance()
804 804 loop = ioloop.IOLoop.instance()
805 805 else:
806 806 # in a process, don't use instance()
807 807 # for safety with multiprocessing
808 808 ctx = zmq.Context()
809 809 loop = ioloop.IOLoop()
810 810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 811 util.set_hwm(ins, 0)
812 812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
813 813 ins.bind(in_addr)
814 814
815 815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
816 816 util.set_hwm(outs, 0)
817 817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
818 818 outs.bind(out_addr)
819 819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
820 820 util.set_hwm(mons, 0)
821 821 mons.connect(mon_addr)
822 822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
823 823 nots.setsockopt(zmq.SUBSCRIBE, b'')
824 824 nots.connect(not_addr)
825 825
826 826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
827 827 querys.connect(reg_addr)
828 828
829 829 # setup logging.
830 830 if in_thread:
831 831 log = Application.instance().log
832 832 else:
833 833 if log_url:
834 834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
835 835 else:
836 836 log = local_logger(logname, loglevel)
837 837
838 838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
839 839 mon_stream=mons, notifier_stream=nots,
840 840 query_stream=querys,
841 841 loop=loop, log=log,
842 842 config=config)
843 843 scheduler.start()
844 844 if not in_thread:
845 845 try:
846 846 loop.start()
847 847 except KeyboardInterrupt:
848 848 scheduler.log.critical("Interrupted, exiting...")
849 849
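For orientation, a minimal sketch of driving this entry point from a parent
process (the tcp addresses are hypothetical placeholders; the real controller
app computes them from its connection info):

    from multiprocessing import Process

    p = Process(target=launch_scheduler, args=(
        'tcp://127.0.0.1:5001',  # in_addr: client-facing ROUTER
        'tcp://127.0.0.1:5002',  # out_addr: engine-facing ROUTER
        'tcp://127.0.0.1:5003',  # mon_addr: PUB to the Hub monitor
        'tcp://127.0.0.1:5004',  # not_addr: SUB for registration events
        'tcp://127.0.0.1:5005',  # reg_addr: DEALER to the Hub query socket
    ))
    p.start()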
@@ -1,301 +1,301
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 It handles registration, etc., and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from __future__ import print_function
10 10
11 11 import sys
12 12 import time
13 13 from getpass import getpass
14 14
15 15 import zmq
16 16 from zmq.eventloop import ioloop, zmqstream
17 17
18 18 from IPython.utils.localinterfaces import localhost
19 19 from IPython.utils.traitlets import (
20 20 Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
21 21 )
22 22 from IPython.utils.py3compat import cast_bytes
23 23
24 24 from IPython.parallel.controller.heartmonitor import Heart
25 25 from IPython.parallel.factory import RegistrationFactory
26 26 from IPython.parallel.util import disambiguate_url
27 27
28 28 from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
29 29 from IPython.kernel.zmq.kernelapp import IPKernelApp
30 30
31 31 class EngineFactory(RegistrationFactory):
32 32 """IPython engine"""
33 33
34 34 # configurables:
35 35 out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
36 36 help="""The OutStream for handling stdout/err.
37 37 Typically 'IPython.kernel.zmq.iostream.OutStream'""")
38 38 display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
39 39 help="""The class for handling displayhook.
40 40 Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
41 41 location=Unicode(config=True,
42 42 help="""The location (an IP address) of the controller. This is
43 43 used for disambiguating URLs, to determine whether
44 44 loopback should be used to connect or the public address.""")
45 45 timeout=Float(5.0, config=True,
46 46 help="""The time (in seconds) to wait for the Controller to respond
47 47 to registration requests before giving up.""")
48 48 max_heartbeat_misses=Integer(50, config=True,
49 49 help="""The maximum number of times a check for the heartbeat ping of a
50 50 controller can be missed before shutting down the engine.
51 51
52 52 If set to 0, the check is disabled.""")
53 53 sshserver=Unicode(config=True,
54 54 help="""The SSH server to use for tunneling connections to the Controller.""")
55 55 sshkey=Unicode(config=True,
56 56 help="""The SSH private key file to use when tunneling connections to the Controller.""")
57 57 paramiko=Bool(sys.platform == 'win32', config=True,
58 58 help="""Whether to use paramiko instead of openssh for tunnels.""")
59 59
60 60 @property
61 61 def tunnel_mod(self):
62 62 from zmq.ssh import tunnel
63 63 return tunnel
64 64
65 65
66 66 # not configurable:
67 67 connection_info = Dict()
68 68 user_ns = Dict()
69 69 id = Integer(allow_none=True)
70 70 registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
71 71 kernel = Instance(Kernel)
72 72 hb_check_period=Integer()
73 73
74 74 # States for the heartbeat monitoring
75 75 # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
76 76 # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
77 77 _hb_last_pinged = 0.0
78 78 _hb_last_monitored = 0.0
79 79 _hb_missed_beats = 0
80 80 # The zmq Stream which receives the pings from the Heart
81 81 _hb_listener = None
82 82
83 83 bident = CBytes()
84 84 ident = Unicode()
85 85 def _ident_changed(self, name, old, new):
86 86 self.bident = cast_bytes(new)
87 87 using_ssh=Bool(False)
88 88
89 89
90 90 def __init__(self, **kwargs):
91 91 super(EngineFactory, self).__init__(**kwargs)
92 92 self.ident = self.session.session
93 93
94 94 def init_connector(self):
95 95 """construct connection function, which handles tunnels."""
96 96 self.using_ssh = bool(self.sshkey or self.sshserver)
97 97
98 98 if self.sshkey and not self.sshserver:
99 99 # We are using ssh directly to the controller, tunneling localhost to localhost
100 100 self.sshserver = self.url.split('://')[1].split(':')[0]
101 101
102 102 if self.using_ssh:
103 103 if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
104 104 password=False
105 105 else:
106 106 password = getpass("SSH Password for %s: "%self.sshserver)
107 107 else:
108 108 password = False
109 109
110 110 def connect(s, url):
111 111 url = disambiguate_url(url, self.location)
112 112 if self.using_ssh:
113 113 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
114 114 return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
115 115 keyfile=self.sshkey, paramiko=self.paramiko,
116 116 password=password,
117 117 )
118 118 else:
119 119 return s.connect(url)
120 120
121 121 def maybe_tunnel(url):
122 122 """like connect, but don't complete the connection (for use by heartbeat)"""
123 123 url = disambiguate_url(url, self.location)
124 124 if self.using_ssh:
125 125 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
126 126 url, tunnelobj = self.tunnel_mod.open_tunnel(url, self.sshserver,
127 127 keyfile=self.sshkey, paramiko=self.paramiko,
128 128 password=password,
129 129 )
130 130 return str(url)
131 131 return connect, maybe_tunnel
132 132
133 133 def register(self):
134 134 """send the registration_request"""
135 135
136 136 self.log.info("Registering with controller at %s"%self.url)
137 137 ctx = self.context
138 138 connect,maybe_tunnel = self.init_connector()
139 139 reg = ctx.socket(zmq.DEALER)
140 140 reg.setsockopt(zmq.IDENTITY, self.bident)
141 141 connect(reg, self.url)
142 142 self.registrar = zmqstream.ZMQStream(reg, self.loop)
143 143
144 144
145 145 content = dict(uuid=self.ident)
146 146 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
147 147 # print (self.session.key)
148 148 self.session.send(self.registrar, "registration_request", content=content)
149 149
150 150 def _report_ping(self, msg):
151 151 """Callback for when the heartmonitor.Heart receives a ping"""
152 152 #self.log.debug("Received a ping: %s", msg)
153 153 self._hb_last_pinged = time.time()
154 154
155 155 def complete_registration(self, msg, connect, maybe_tunnel):
156 156 # print msg
157 157 self.loop.remove_timeout(self._abort_timeout)
158 158 ctx = self.context
159 159 loop = self.loop
160 160 identity = self.bident
161 161 idents,msg = self.session.feed_identities(msg)
162 msg = self.session.unserialize(msg)
162 msg = self.session.deserialize(msg)
163 163 content = msg['content']
164 164 info = self.connection_info
165 165
166 166 def url(key):
167 167 """get zmq url for given channel"""
168 168 return str(info["interface"] + ":%i" % info[key])
169 169
170 170 if content['status'] == 'ok':
171 171 self.id = int(content['id'])
172 172
173 173 # launch heartbeat
174 174 # possibly forward hb ports with tunnels
175 175 hb_ping = maybe_tunnel(url('hb_ping'))
176 176 hb_pong = maybe_tunnel(url('hb_pong'))
177 177
178 178 hb_monitor = None
179 179 if self.max_heartbeat_misses > 0:
180 180 # Add a monitor socket which will record the last time a ping was seen
181 181 mon = self.context.socket(zmq.SUB)
182 182 mport = mon.bind_to_random_port('tcp://%s' % localhost())
183 183 mon.setsockopt(zmq.SUBSCRIBE, b"")
184 184 self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
185 185 self._hb_listener.on_recv(self._report_ping)
186 186
187 187
188 188 hb_monitor = "tcp://%s:%i" % (localhost(), mport)
189 189
190 190 heart = Heart(hb_ping, hb_pong, hb_monitor , heart_id=identity)
191 191 heart.start()
192 192
193 193 # create Shell Connections (MUX, Task, etc.):
194 194 shell_addrs = url('mux'), url('task')
195 195
196 196 # Use only one shell stream for mux and tasks
197 197 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
198 198 stream.setsockopt(zmq.IDENTITY, identity)
199 199 shell_streams = [stream]
200 200 for addr in shell_addrs:
201 201 connect(stream, addr)
202 202
203 203 # control stream:
204 204 control_addr = url('control')
205 205 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
206 206 control_stream.setsockopt(zmq.IDENTITY, identity)
207 207 connect(control_stream, control_addr)
208 208
209 209 # create iopub stream:
210 210 iopub_addr = url('iopub')
211 211 iopub_socket = ctx.socket(zmq.PUB)
212 212 iopub_socket.setsockopt(zmq.IDENTITY, identity)
213 213 connect(iopub_socket, iopub_addr)
214 214
215 215 # disable history:
216 216 self.config.HistoryManager.hist_file = ':memory:'
217 217
218 218             # Redirect output streams and set a display hook.
219 219 if self.out_stream_factory:
220 220 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
221 221 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
222 222 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
223 223 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
224 224 if self.display_hook_factory:
225 225 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
226 226 sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)
227 227
228 228 self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
229 229 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
230 230 loop=loop, user_ns=self.user_ns, log=self.log)
231 231
232 232 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
233 233
234 234
235 235 # periodically check the heartbeat pings of the controller
236 236 # Should be started here and not in "start()" so that the right period can be taken
237 237             # from the hub's HeartBeatMonitor.period
238 238 if self.max_heartbeat_misses > 0:
239 239                 # Use a slightly bigger check period than the hub signal period so as not to warn unnecessarily
240 240 self.hb_check_period = int(content['hb_period'])+10
241 241 self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms." , self.hb_check_period)
242 242 self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
243 243 self._hb_reporter.start()
244 244 else:
245 245 self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")
246 246
247 247
248 248 # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
249 249 app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
250 250 app.init_profile_dir()
251 251 app.init_code()
252 252
253 253 self.kernel.start()
254 254 else:
255 255 self.log.fatal("Registration Failed: %s"%msg)
256 256 raise Exception("Registration Failed: %s"%msg)
257 257
258 258 self.log.info("Completed registration with id %i"%self.id)
259 259
260 260
261 261 def abort(self):
262 262 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
263 263 if self.url.startswith('127.'):
264 264 self.log.fatal("""
265 265 If the controller and engines are not on the same machine,
266 266 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
267 267 c.HubFactory.ip='*' # for all interfaces, internal and external
268 268 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
269 269 or tunnel connections via ssh.
270 270 """)
271 271 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
272 272 time.sleep(1)
273 273 sys.exit(255)
274 274
275 275 def _hb_monitor(self):
276 276 """Callback to monitor the heartbeat from the controller"""
277 277 self._hb_listener.flush()
278 278 if self._hb_last_monitored > self._hb_last_pinged:
279 279 self._hb_missed_beats += 1
280 280 self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
281 281 else:
282 282 #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
283 283 self._hb_missed_beats = 0
284 284
285 285 if self._hb_missed_beats >= self.max_heartbeat_misses:
286 286             self.log.fatal("Maximum number of heartbeat misses reached (%s times %s ms), shutting down.",
287 287 self.max_heartbeat_misses, self.hb_check_period)
288 288 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
289 289 self.loop.stop()
290 290
291 291 self._hb_last_monitored = time.time()
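        # Worked example (assuming a hub heartbeat period of 3000 ms): the
        # check period computed in complete_registration becomes 3010 ms, so
        # with the default max_heartbeat_misses of 50 a silent controller is
        # detected after roughly 50 * 3.01 s, about 2.5 minutes, at which
        # point the engine unregisters and stops its loop.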
292 292
293 293
294 294 def start(self):
295 295 loop = self.loop
296 296 def _start():
297 297 self.register()
298 298 self._abort_timeout = loop.add_timeout(loop.time() + self.timeout, self.abort)
299 299 self.loop.add_callback(_start)
300 300
301 301
@@ -1,734 +1,735
1 1 # encoding: utf-8
2 2 """
3 3 This module defines the things that are used in setup.py for building IPython
4 4
5 5 This includes:
6 6
7 7 * The basic arguments to setup
8 8 * Functions for finding things like packages, package data, etc.
9 9 * A function for checking dependencies.
10 10 """
11 11
12 12 # Copyright (c) IPython Development Team.
13 13 # Distributed under the terms of the Modified BSD License.
14 14
15 15 from __future__ import print_function
16 16
17 17 import errno
18 18 import os
19 19 import sys
20 20
21 21 from distutils import log
22 22 from distutils.command.build_py import build_py
23 23 from distutils.command.build_scripts import build_scripts
24 24 from distutils.command.install import install
25 25 from distutils.command.install_scripts import install_scripts
26 26 from distutils.cmd import Command
27 27 from fnmatch import fnmatch
28 28 from glob import glob
29 29 from subprocess import check_call
30 30
31 31 from setupext import install_data_ext
32 32
33 33 #-------------------------------------------------------------------------------
34 34 # Useful globals and utility functions
35 35 #-------------------------------------------------------------------------------
36 36
37 37 # A few handy globals
38 38 isfile = os.path.isfile
39 39 pjoin = os.path.join
40 40 repo_root = os.path.dirname(os.path.abspath(__file__))
41 41
42 42 def oscmd(s):
43 43 print(">", s)
44 44 os.system(s)
45 45
46 46 # Py3 compatibility hacks, without assuming IPython itself is installed with
47 47 # the full py3compat machinery.
48 48
49 49 try:
50 50 execfile
51 51 except NameError:
52 52 def execfile(fname, globs, locs=None):
53 53 locs = locs or globs
54 54 exec(compile(open(fname).read(), fname, "exec"), globs, locs)
55 55
56 56 # A little utility we'll need below, since glob() does NOT allow you to do
57 57 # exclusion on multiple endings!
58 58 def file_doesnt_endwith(test,endings):
59 59 """Return true if test is a file and its name does NOT end with any
60 60 of the strings listed in endings."""
61 61 if not isfile(test):
62 62 return False
63 63 for e in endings:
64 64 if test.endswith(e):
65 65 return False
66 66 return True
67 67
68 68 #---------------------------------------------------------------------------
69 69 # Basic project information
70 70 #---------------------------------------------------------------------------
71 71
72 72 # release.py contains version, authors, license, url, keywords, etc.
73 73 execfile(pjoin(repo_root, 'IPython','core','release.py'), globals())
74 74
75 75 # Create a dict with the basic information
76 76 # This dict is eventually passed to setup after additional keys are added.
77 77 setup_args = dict(
78 78 name = name,
79 79 version = version,
80 80 description = description,
81 81 long_description = long_description,
82 82 author = author,
83 83 author_email = author_email,
84 84 url = url,
85 85 download_url = download_url,
86 86 license = license,
87 87 platforms = platforms,
88 88 keywords = keywords,
89 89 classifiers = classifiers,
90 90 cmdclass = {'install_data': install_data_ext},
91 91 )
92 92
93 93
94 94 #---------------------------------------------------------------------------
95 95 # Find packages
96 96 #---------------------------------------------------------------------------
97 97
98 98 def find_packages():
99 99 """
100 100 Find all of IPython's packages.
101 101 """
102 102 excludes = ['deathrow', 'quarantine']
103 103 packages = []
104 104 for dir,subdirs,files in os.walk('IPython'):
105 105 package = dir.replace(os.path.sep, '.')
106 106 if any(package.startswith('IPython.'+exc) for exc in excludes):
107 107 # package is to be excluded (e.g. deathrow)
108 108 continue
109 109 if '__init__.py' not in files:
110 110 # not a package
111 111 continue
112 112 packages.append(package)
113 113 return packages
114 114
115 115 #---------------------------------------------------------------------------
116 116 # Find package data
117 117 #---------------------------------------------------------------------------
118 118
119 119 def find_package_data():
120 120 """
121 121 Find IPython's package_data.
122 122 """
123 123 # This is not enough for these things to appear in an sdist.
124 124 # We need to muck with the MANIFEST to get this to work
125 125
126 126 # exclude components and less from the walk;
127 127 # we will build the components separately
128 128 excludes = [
129 129 pjoin('static', 'components'),
130 130 pjoin('static', '*', 'less'),
131 131 ]
132 132
133 133 # walk notebook resources:
134 134 cwd = os.getcwd()
135 135 os.chdir(os.path.join('IPython', 'html'))
136 136 static_data = []
137 137 for parent, dirs, files in os.walk('static'):
138 138 if any(fnmatch(parent, pat) for pat in excludes):
139 139 # prevent descending into subdirs
140 140 dirs[:] = []
141 141 continue
142 142 for f in files:
143 143 static_data.append(pjoin(parent, f))
144 144
145 145 components = pjoin("static", "components")
146 146 # select the components we actually need to install
147 147 # (there are lots of resources we bundle for sdist-reasons that we don't actually use)
148 148 static_data.extend([
149 149 pjoin(components, "backbone", "backbone-min.js"),
150 150 pjoin(components, "bootstrap", "js", "bootstrap.min.js"),
151 151 pjoin(components, "bootstrap-tour", "build", "css", "bootstrap-tour.min.css"),
152 152 pjoin(components, "bootstrap-tour", "build", "js", "bootstrap-tour.min.js"),
153 153 pjoin(components, "font-awesome", "fonts", "*.*"),
154 154 pjoin(components, "google-caja", "html-css-sanitizer-minified.js"),
155 155 pjoin(components, "highlight.js", "build", "highlight.pack.js"),
156 156 pjoin(components, "jquery", "jquery.min.js"),
157 157 pjoin(components, "jquery-ui", "ui", "minified", "jquery-ui.min.js"),
158 158 pjoin(components, "jquery-ui", "themes", "smoothness", "jquery-ui.min.css"),
159 159 pjoin(components, "jquery-ui", "themes", "smoothness", "images", "*"),
160 160 pjoin(components, "marked", "lib", "marked.js"),
161 161 pjoin(components, "requirejs", "require.js"),
162 162 pjoin(components, "underscore", "underscore-min.js"),
163 163 pjoin(components, "moment", "moment.js"),
164 164 pjoin(components, "moment", "min","moment.min.js"),
165 pjoin(components, "text-encoding", "lib", "encoding.js"),
165 166 ])
166 167
167 168 # Ship all of Codemirror's CSS and JS
168 169 for parent, dirs, files in os.walk(pjoin(components, 'codemirror')):
169 170 for f in files:
170 171 if f.endswith(('.js', '.css')):
171 172 static_data.append(pjoin(parent, f))
172 173
173 174 os.chdir(os.path.join('tests',))
174 175 js_tests = glob('*.js') + glob('*/*.js')
175 176
176 177 os.chdir(os.path.join(cwd, 'IPython', 'nbconvert'))
177 178 nbconvert_templates = [os.path.join(dirpath, '*.*')
178 179 for dirpath, _, _ in os.walk('templates')]
179 180
180 181 os.chdir(cwd)
181 182
182 183 package_data = {
183 184 'IPython.config.profile' : ['README*', '*/*.py'],
184 185 'IPython.core.tests' : ['*.png', '*.jpg'],
185 186 'IPython.lib.tests' : ['*.wav'],
186 187 'IPython.testing.plugin' : ['*.txt'],
187 188 'IPython.html' : ['templates/*'] + static_data,
188 189 'IPython.html.tests' : js_tests,
189 190 'IPython.qt.console' : ['resources/icon/*.svg'],
190 191 'IPython.nbconvert' : nbconvert_templates +
191 192 [
192 193 'tests/files/*.*',
193 194 'exporters/tests/files/*.*',
194 195 'preprocessors/tests/files/*.*',
195 196 ],
196 197 'IPython.nbconvert.filters' : ['marked.js'],
197 198 'IPython.nbformat' : [
198 199 'tests/*.ipynb',
199 200 'v3/nbformat.v3.schema.json',
200 201 ]
201 202 }
202 203
203 204 return package_data
204 205
205 206
206 207 def check_package_data(package_data):
207 208 """verify that package_data globs make sense"""
208 209 print("checking package data")
209 210 for pkg, data in package_data.items():
210 211 pkg_root = pjoin(*pkg.split('.'))
211 212 for d in data:
212 213 path = pjoin(pkg_root, d)
213 214 if '*' in path:
214 215 assert len(glob(path)) > 0, "No files match pattern %s" % path
215 216 else:
216 217 assert os.path.exists(path), "Missing package data: %s" % path
217 218
218 219
219 220 def check_package_data_first(command):
220 221 """decorator for checking package_data before running a given command
221 222
222 223 Probably only needs to wrap build_py
223 224 """
224 225 class DecoratedCommand(command):
225 226 def run(self):
226 227 check_package_data(self.package_data)
227 228 command.run(self)
228 229 return DecoratedCommand
229 230
230 231
231 232 #---------------------------------------------------------------------------
232 233 # Find data files
233 234 #---------------------------------------------------------------------------
234 235
235 236 def make_dir_struct(tag,base,out_base):
236 237 """Make the directory structure of all files below a starting dir.
237 238
238 239 This is just a convenience routine to help build a nested directory
239 240 hierarchy because distutils is too stupid to do this by itself.
240 241
241 242 XXX - this needs a proper docstring!
242 243 """
243 244
244 245 # we'll use these a lot below
245 246 lbase = len(base)
246 247 pathsep = os.path.sep
247 248 lpathsep = len(pathsep)
248 249
249 250 out = []
250 251 for (dirpath,dirnames,filenames) in os.walk(base):
251 252 # we need to strip out the dirpath from the base to map it to the
252 253 # output (installation) path. This requires possibly stripping the
253 254 # path separator, because otherwise pjoin will not work correctly
254 255 # (pjoin('foo/','/bar') returns '/bar').
255 256
256 257 dp_eff = dirpath[lbase:]
257 258 if dp_eff.startswith(pathsep):
258 259 dp_eff = dp_eff[lpathsep:]
259 260 # The output path must be anchored at the out_base marker
260 261 out_path = pjoin(out_base,dp_eff)
261 262 # Now we can generate the final filenames. Since os.walk only produces
262 263 # filenames, we must join back with the dirpath to get full valid file
263 264 # paths:
264 265 pfiles = [pjoin(dirpath,f) for f in filenames]
265 266         # Finally, generate the entry we need, which is a pair of (output
266 267 # path, files) for use as a data_files parameter in install_data.
267 268 out.append((out_path, pfiles))
268 269
269 270 return out
270 271
271 272
272 273 def find_data_files():
273 274 """
274 275 Find IPython's data_files.
275 276
276 277 Just man pages at this point.
277 278 """
278 279
279 280 manpagebase = pjoin('share', 'man', 'man1')
280 281
281 282 # Simple file lists can be made by hand
282 283 manpages = [f for f in glob(pjoin('docs','man','*.1.gz')) if isfile(f)]
283 284 if not manpages:
284 285 # When running from a source tree, the manpages aren't gzipped
285 286 manpages = [f for f in glob(pjoin('docs','man','*.1')) if isfile(f)]
286 287
287 288 # And assemble the entire output list
288 289 data_files = [ (manpagebase, manpages) ]
289 290
290 291 return data_files
291 292
292 293
293 294 def make_man_update_target(manpage):
294 295 """Return a target_update-compliant tuple for the given manpage.
295 296
296 297 Parameters
297 298 ----------
298 299 manpage : string
299 300 Name of the manpage, must include the section number (trailing number).
300 301
301 302 Example
302 303 -------
303 304
304 305 >>> make_man_update_target('ipython.1') #doctest: +NORMALIZE_WHITESPACE
305 306 ('docs/man/ipython.1.gz',
306 307 ['docs/man/ipython.1'],
307 308 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz')
308 309 """
309 310 man_dir = pjoin('docs', 'man')
310 311 manpage_gz = manpage + '.gz'
311 312 manpath = pjoin(man_dir, manpage)
312 313 manpath_gz = pjoin(man_dir, manpage_gz)
313 314 gz_cmd = ( "cd %(man_dir)s && gzip -9c %(manpage)s > %(manpage_gz)s" %
314 315 locals() )
315 316 return (manpath_gz, [manpath], gz_cmd)
316 317
317 318 # The two functions below are copied from IPython.utils.path, so we don't need
318 319 # to import IPython during setup, which fails on Python 3.
319 320
320 321 def target_outdated(target,deps):
321 322 """Determine whether a target is out of date.
322 323
323 324 target_outdated(target,deps) -> 1/0
324 325
325 326 deps: list of filenames which MUST exist.
326 327 target: single filename which may or may not exist.
327 328
328 329 If target doesn't exist or is older than any file listed in deps, return
329 330 true, otherwise return false.
330 331 """
331 332 try:
332 333 target_time = os.path.getmtime(target)
333 334 except os.error:
334 335 return 1
335 336 for dep in deps:
336 337 dep_time = os.path.getmtime(dep)
337 338 if dep_time > target_time:
338 339 #print "For target",target,"Dep failed:",dep # dbg
339 340 #print "times (dep,tar):",dep_time,target_time # dbg
340 341 return 1
341 342 return 0
342 343
343 344
344 345 def target_update(target,deps,cmd):
345 346 """Update a target with a given command given a list of dependencies.
346 347
347 348 target_update(target,deps,cmd) -> runs cmd if target is outdated.
348 349
349 350 This is just a wrapper around target_outdated() which calls the given
350 351 command if target is outdated."""
351 352
352 353 if target_outdated(target,deps):
353 354 os.system(cmd)
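# Typical usage, combining the two helpers above with the man page targets
# (the 'ipython.1' page name is just the example from the doctest):
#
#   to_update = [make_man_update_target('ipython.1')]
#   [target_update(*t) for t in to_update]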
354 355
355 356 #---------------------------------------------------------------------------
356 357 # Find scripts
357 358 #---------------------------------------------------------------------------
358 359
359 360 def find_entry_points():
360 361 """Find IPython's scripts.
361 362
362 363     Each script is defined as a setuptools entry_point, and each one is
363 364     returned twice: once with a plain name, and once with the Python
364 365     major version appended as a suffix, so that the Python 3 scripts
365 366     get named "ipython3" etc.
369 370 """
370 371 ep = [
371 372 'ipython%s = IPython:start_ipython',
372 373 'ipcontroller%s = IPython.parallel.apps.ipcontrollerapp:launch_new_instance',
373 374 'ipengine%s = IPython.parallel.apps.ipengineapp:launch_new_instance',
374 375 'ipcluster%s = IPython.parallel.apps.ipclusterapp:launch_new_instance',
375 376 'iptest%s = IPython.testing.iptestcontroller:main',
376 377 ]
377 378 suffix = str(sys.version_info[0])
378 379 return [e % '' for e in ep] + [e % suffix for e in ep]
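    # e.g. on Python 3 this yields both 'ipython = IPython:start_ipython'
    # and 'ipython3 = IPython:start_ipython' for each entry above.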
379 380
380 381 script_src = """#!{executable}
381 382 # This script was automatically generated by setup.py
382 383 if __name__ == '__main__':
383 384 from {mod} import {func}
384 385 {func}()
385 386 """
386 387
387 388 class build_scripts_entrypt(build_scripts):
388 389 def run(self):
389 390 self.mkpath(self.build_dir)
390 391 outfiles = []
391 392 for script in find_entry_points():
392 393 name, entrypt = script.split('=')
393 394 name = name.strip()
394 395 entrypt = entrypt.strip()
395 396 outfile = os.path.join(self.build_dir, name)
396 397 outfiles.append(outfile)
397 398 print('Writing script to', outfile)
398 399
399 400 mod, func = entrypt.split(':')
400 401 with open(outfile, 'w') as f:
401 402 f.write(script_src.format(executable=sys.executable,
402 403 mod=mod, func=func))
403 404
404 405 return outfiles, outfiles
405 406
406 407 class install_lib_symlink(Command):
407 408 user_options = [
408 409 ('install-dir=', 'd', "directory to install to"),
409 410 ]
410 411
411 412 def initialize_options(self):
412 413 self.install_dir = None
413 414
414 415 def finalize_options(self):
415 416 self.set_undefined_options('symlink',
416 417 ('install_lib', 'install_dir'),
417 418 )
418 419
419 420 def run(self):
420 421 if sys.platform == 'win32':
421 422 raise Exception("This doesn't work on Windows.")
422 423 pkg = os.path.join(os.getcwd(), 'IPython')
423 424 dest = os.path.join(self.install_dir, 'IPython')
424 425 if os.path.islink(dest):
425 426 print('removing existing symlink at %s' % dest)
426 427 os.unlink(dest)
427 428 print('symlinking %s -> %s' % (pkg, dest))
428 429 os.symlink(pkg, dest)
429 430
430 431 class unsymlink(install):
431 432 def run(self):
432 433 dest = os.path.join(self.install_lib, 'IPython')
433 434 if os.path.islink(dest):
434 435 print('removing symlink at %s' % dest)
435 436 os.unlink(dest)
436 437 else:
437 438 print('No symlink exists at %s' % dest)
438 439
439 440 class install_symlinked(install):
440 441 def run(self):
441 442 if sys.platform == 'win32':
442 443 raise Exception("This doesn't work on Windows.")
443 444
444 445 # Run all sub-commands (at least those that need to be run)
445 446 for cmd_name in self.get_sub_commands():
446 447 self.run_command(cmd_name)
447 448
448 449 # 'sub_commands': a list of commands this command might have to run to
449 450 # get its work done. See cmd.py for more info.
450 451 sub_commands = [('install_lib_symlink', lambda self:True),
451 452 ('install_scripts_sym', lambda self:True),
452 453 ]
453 454
454 455 class install_scripts_for_symlink(install_scripts):
455 456 """Redefined to get options from 'symlink' instead of 'install'.
456 457
457 458 I love distutils almost as much as I love setuptools.
458 459 """
459 460 def finalize_options(self):
460 461 self.set_undefined_options('build', ('build_scripts', 'build_dir'))
461 462 self.set_undefined_options('symlink',
462 463 ('install_scripts', 'install_dir'),
463 464 ('force', 'force'),
464 465 ('skip_build', 'skip_build'),
465 466 )
466 467
467 468 #---------------------------------------------------------------------------
468 469 # Verify all dependencies
469 470 #---------------------------------------------------------------------------
470 471
471 472 def check_for_dependencies():
472 473 """Check for IPython's dependencies.
473 474
474 475 This function should NOT be called if running under setuptools!
475 476 """
476 477 from setupext.setupext import (
477 478 print_line, print_raw, print_status,
478 479 check_for_sphinx, check_for_pygments,
479 480 check_for_nose, check_for_pexpect,
480 481 check_for_pyzmq, check_for_readline,
481 482 check_for_jinja2, check_for_tornado
482 483 )
483 484 print_line()
484 485 print_raw("BUILDING IPYTHON")
485 486 print_status('python', sys.version)
486 487 print_status('platform', sys.platform)
487 488 if sys.platform == 'win32':
488 489 print_status('Windows version', sys.getwindowsversion())
489 490
490 491 print_raw("")
491 492 print_raw("OPTIONAL DEPENDENCIES")
492 493
493 494 check_for_sphinx()
494 495 check_for_pygments()
495 496 check_for_nose()
496 497 if os.name == 'posix':
497 498 check_for_pexpect()
498 499 check_for_pyzmq()
499 500 check_for_tornado()
500 501 check_for_readline()
501 502 check_for_jinja2()
502 503
503 504 #---------------------------------------------------------------------------
504 505 # VCS related
505 506 #---------------------------------------------------------------------------
506 507
507 508 # utils.submodule has checks for submodule status
508 509 execfile(pjoin('IPython','utils','submodule.py'), globals())
509 510
510 511 class UpdateSubmodules(Command):
511 512 """Update git submodules
512 513
513 514 IPython's external javascript dependencies live in a separate repo.
514 515 """
515 516 description = "Update git submodules"
516 517 user_options = []
517 518
518 519 def initialize_options(self):
519 520 pass
520 521
521 522 def finalize_options(self):
522 523 pass
523 524
524 525 def run(self):
525 526 failure = False
526 527 try:
527 528 self.spawn('git submodule init'.split())
528 529 self.spawn('git submodule update --recursive'.split())
529 530 except Exception as e:
530 531 failure = e
531 532 print(e)
532 533
533 534 if not check_submodule_status(repo_root) == 'clean':
534 535 print("submodules could not be checked out")
535 536 sys.exit(1)
536 537
537 538
538 539 def git_prebuild(pkg_dir, build_cmd=build_py):
539 540     """Return an extended build or sdist command class for recording the commit
540 541
541 542 records git commit in IPython.utils._sysinfo.commit
542 543
543 544 for use in IPython.utils.sysinfo.sys_info() calls after installation.
544 545
545 546 Also ensures that submodules exist prior to running
546 547 """
547 548
548 549 class MyBuildPy(build_cmd):
549 550 ''' Subclass to write commit data into installation tree '''
550 551 def run(self):
551 552 build_cmd.run(self)
552 553 # this one will only fire for build commands
553 554 if hasattr(self, 'build_lib'):
554 555 self._record_commit(self.build_lib)
555 556
556 557 def make_release_tree(self, base_dir, files):
557 558 # this one will fire for sdist
558 559 build_cmd.make_release_tree(self, base_dir, files)
559 560 self._record_commit(base_dir)
560 561
561 562 def _record_commit(self, base_dir):
562 563 import subprocess
563 564 proc = subprocess.Popen('git rev-parse --short HEAD',
564 565 stdout=subprocess.PIPE,
565 566 stderr=subprocess.PIPE,
566 567 shell=True)
567 568 repo_commit, _ = proc.communicate()
568 569 repo_commit = repo_commit.strip().decode("ascii")
569 570
570 571 out_pth = pjoin(base_dir, pkg_dir, 'utils', '_sysinfo.py')
571 572 if os.path.isfile(out_pth) and not repo_commit:
572 573 # nothing to write, don't clobber
573 574 return
574 575
575 576 print("writing git commit '%s' to %s" % (repo_commit, out_pth))
576 577
577 578 # remove to avoid overwriting original via hard link
578 579 try:
579 580 os.remove(out_pth)
580 581 except (IOError, OSError):
581 582 pass
582 583 with open(out_pth, 'w') as out_file:
583 584 out_file.writelines([
584 585 '# GENERATED BY setup.py\n',
585 586 'commit = u"%s"\n' % repo_commit,
586 587 ])
587 588 return require_submodules(MyBuildPy)
588 589
589 590
590 591 def require_submodules(command):
591 592 """decorator for instructing a command to check for submodules before running"""
592 593 class DecoratedCommand(command):
593 594 def run(self):
594 595 if not check_submodule_status(repo_root) == 'clean':
595 596 print("submodules missing! Run `setup.py submodule` and try again")
596 597 sys.exit(1)
597 598 command.run(self)
598 599 return DecoratedCommand
599 600
600 601 #---------------------------------------------------------------------------
601 602 # bdist related
602 603 #---------------------------------------------------------------------------
603 604
604 605 def get_bdist_wheel():
605 606 """Construct bdist_wheel command for building wheels
606 607
607 608 Constructs py2-none-any tag, instead of py2.7-none-any
608 609 """
609 610 class RequiresWheel(Command):
610 611 description = "Dummy command for missing bdist_wheel"
611 612 user_options = []
612 613
613 614 def initialize_options(self):
614 615 pass
615 616
616 617 def finalize_options(self):
617 618 pass
618 619
619 620 def run(self):
620 621 print("bdist_wheel requires the wheel package")
621 622 sys.exit(1)
622 623
623 624 if 'setuptools' not in sys.modules:
624 625 return RequiresWheel
625 626 else:
626 627 try:
627 628 from wheel.bdist_wheel import bdist_wheel, read_pkg_info, write_pkg_info
628 629 except ImportError:
629 630 return RequiresWheel
630 631
631 632 class bdist_wheel_tag(bdist_wheel):
632 633
633 634 def add_requirements(self, metadata_path):
634 635 """transform platform-dependent requirements"""
635 636 pkg_info = read_pkg_info(metadata_path)
636 637 # pkg_info is an email.Message object (?!)
637 638 # we have to remove the unconditional 'readline' and/or 'pyreadline' entries
638 639 # and transform them to conditionals
639 640 requires = pkg_info.get_all('Requires-Dist')
640 641 del pkg_info['Requires-Dist']
641 642 def _remove_startswith(lis, prefix):
642 643 """like list.remove, but with startswith instead of =="""
643 644 found = False
644 645 for idx, item in enumerate(lis):
645 646 if item.startswith(prefix):
646 647 found = True
647 648 break
648 649 if found:
649 650 lis.pop(idx)
650 651
651 652 for pkg in ("gnureadline", "pyreadline", "mock"):
652 653 _remove_startswith(requires, pkg)
653 654 requires.append("gnureadline; sys.platform == 'darwin' and platform.python_implementation == 'CPython'")
654 655 requires.append("pyreadline (>=2.0); extra == 'terminal' and sys.platform == 'win32' and platform.python_implementation == 'CPython'")
655 656 requires.append("pyreadline (>=2.0); extra == 'all' and sys.platform == 'win32' and platform.python_implementation == 'CPython'")
656 657 requires.append("mock; extra == 'test' and python_version < '3.3'")
657 658 for r in requires:
658 659 pkg_info['Requires-Dist'] = r
659 660 write_pkg_info(metadata_path, pkg_info)
660 661
661 662 return bdist_wheel_tag
662 663
663 664 #---------------------------------------------------------------------------
664 665 # Notebook related
665 666 #---------------------------------------------------------------------------
666 667
667 668 class CompileCSS(Command):
668 669 """Recompile Notebook CSS
669 670
670 671 Regenerate the compiled CSS from LESS sources.
671 672
672 673 Requires various dev dependencies, such as invoke and lessc.
673 674 """
674 675 description = "Recompile Notebook CSS"
675 676 user_options = [
676 677 ('minify', 'x', "minify CSS"),
677 678 ('force', 'f', "force recompilation of CSS"),
678 679 ]
679 680
680 681 def initialize_options(self):
681 682 self.minify = False
682 683 self.force = False
683 684
684 685 def finalize_options(self):
685 686 self.minify = bool(self.minify)
686 687 self.force = bool(self.force)
687 688
688 689 def run(self):
689 690 cmd = ['invoke', 'css']
690 691 if self.minify:
691 692 cmd.append('--minify')
692 693 if self.force:
693 694 cmd.append('--force')
694 695 check_call(cmd, cwd=pjoin(repo_root, "IPython", "html"))
695 696
696 697
697 698 class JavascriptVersion(Command):
698 699 """write the javascript version to notebook javascript"""
699 700 description = "Write IPython version to javascript"
700 701 user_options = []
701 702
702 703 def initialize_options(self):
703 704 pass
704 705
705 706 def finalize_options(self):
706 707 pass
707 708
708 709 def run(self):
709 710 nsfile = pjoin(repo_root, "IPython", "html", "static", "base", "js", "namespace.js")
710 711 with open(nsfile) as f:
711 712 lines = f.readlines()
712 713 with open(nsfile, 'w') as f:
713 714 for line in lines:
714 715 if line.startswith("IPython.version"):
715 716 line = 'IPython.version = "{0}";\n'.format(version)
716 717 f.write(line)
717 718
718 719
719 720 def css_js_prerelease(command, strict=True):
720 721 """decorator for building js/minified css prior to a release"""
721 722 class DecoratedCommand(command):
722 723 def run(self):
723 724 self.distribution.run_command('jsversion')
724 725 css = self.distribution.get_command_obj('css')
725 726 css.minify = True
726 727 try:
727 728 self.distribution.run_command('css')
728 729 except Exception as e:
729 730 if strict:
730 731 raise
731 732 else:
732 733 log.warn("Failed to build css sourcemaps: %s" % e)
733 734 command.run(self)
734 735 return DecoratedCommand
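A hedged sketch of how these command classes might be composed into a cmdclass
dict for setup() (the exact wiring lives in setup.py and may differ):

    from distutils.command.sdist import sdist

    cmdclass = {
        'build_py': css_js_prerelease(
            check_package_data_first(git_prebuild('IPython'))),
        'sdist': css_js_prerelease(git_prebuild('IPython', sdist)),
        'submodule': UpdateSubmodules,
        'css': CompileCSS,
        'jsversion': JavascriptVersion,
        'symlink': install_symlinked,
        'install_lib_symlink': install_lib_symlink,
        'install_scripts_sym': install_scripts_for_symlink,
        'unsymlink': unsymlink,
        'bdist_wheel': get_bdist_wheel(),
    }
    setup_args['cmdclass'].update(cmdclass)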