##// END OF EJS Templates
pyin -> execute_input
MinRK -
Show More
@@ -1,442 +1,442 b''
1 1 """Test suite for our zeromq-based message specification."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import re
7 7 from distutils.version import LooseVersion as V
8 8 from subprocess import PIPE
9 9 try:
10 10 from queue import Empty # Py 3
11 11 except ImportError:
12 12 from Queue import Empty # Py 2
13 13
14 14 import nose.tools as nt
15 15
16 16 from IPython.kernel import KernelManager
17 17
18 18 from IPython.utils.traitlets import (
19 19 HasTraits, TraitError, Bool, Unicode, Dict, Integer, List, Enum, Any,
20 20 )
21 21 from IPython.utils.py3compat import string_types, iteritems
22 22
23 23 from .utils import TIMEOUT, start_global_kernel, flush_channels, execute
24 24
25 25 #-----------------------------------------------------------------------------
26 26 # Globals
27 27 #-----------------------------------------------------------------------------
28 28 KC = None
29 29
30 30 def setup():
31 31 global KC
32 32 KC = start_global_kernel()
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # Message Spec References
36 36 #-----------------------------------------------------------------------------
37 37
38 38 class Reference(HasTraits):
39 39
40 40 """
41 41 Base class for message spec specification testing.
42 42
43 43 This class is the core of the message specification test. The
44 44 idea is that child classes implement trait attributes for each
45 45 message keys, so that message keys can be tested against these
46 46 traits using :meth:`check` method.
47 47
48 48 """
49 49
50 50 def check(self, d):
51 51 """validate a dict against our traits"""
52 52 for key in self.trait_names():
53 53 nt.assert_in(key, d)
54 54 # FIXME: always allow None, probably not a good idea
55 55 if d[key] is None:
56 56 continue
57 57 try:
58 58 setattr(self, key, d[key])
59 59 except TraitError as e:
60 60 assert False, str(e)
61 61
62 62 class Version(Unicode):
63 63 def validate(self, obj, value):
64 64 min_version = self.default_value
65 65 if V(value) < V(min_version):
66 66 raise TraitError("bad version: %s < %s" % (value, min_version))
67 67
68 68 class RMessage(Reference):
69 69 msg_id = Unicode()
70 70 msg_type = Unicode()
71 71 header = Dict()
72 72 parent_header = Dict()
73 73 content = Dict()
74 74
75 75 def check(self, d):
76 76 super(RMessage, self).check(d)
77 77 RHeader().check(self.header)
78 78 RHeader().check(self.parent_header)
79 79
80 80 class RHeader(Reference):
81 81 msg_id = Unicode()
82 82 msg_type = Unicode()
83 83 session = Unicode()
84 84 username = Unicode()
85 85 version = Version('5.0')
86 86
87 87
88 88 class ExecuteReply(Reference):
89 89 execution_count = Integer()
90 90 status = Enum((u'ok', u'error'))
91 91
92 92 def check(self, d):
93 93 Reference.check(self, d)
94 94 if d['status'] == 'ok':
95 95 ExecuteReplyOkay().check(d)
96 96 elif d['status'] == 'error':
97 97 ExecuteReplyError().check(d)
98 98
99 99
100 100 class ExecuteReplyOkay(Reference):
101 101 payload = List(Dict)
102 102 user_variables = Dict()
103 103 user_expressions = Dict()
104 104
105 105
106 106 class ExecuteReplyError(Reference):
107 107 ename = Unicode()
108 108 evalue = Unicode()
109 109 traceback = List(Unicode)
110 110
111 111
112 112 class OInfoReply(Reference):
113 113 name = Unicode()
114 114 found = Bool()
115 115 ismagic = Bool()
116 116 isalias = Bool()
117 117 namespace = Enum((u'builtin', u'magics', u'alias', u'Interactive'))
118 118 type_name = Unicode()
119 119 string_form = Unicode()
120 120 base_class = Unicode()
121 121 length = Integer()
122 122 file = Unicode()
123 123 definition = Unicode()
124 124 argspec = Dict()
125 125 init_definition = Unicode()
126 126 docstring = Unicode()
127 127 init_docstring = Unicode()
128 128 class_docstring = Unicode()
129 129 call_def = Unicode()
130 130 call_docstring = Unicode()
131 131 source = Unicode()
132 132
133 133 def check(self, d):
134 134 super(OInfoReply, self).check(d)
135 135 if d['argspec'] is not None:
136 136 ArgSpec().check(d['argspec'])
137 137
138 138
139 139 class ArgSpec(Reference):
140 140 args = List(Unicode)
141 141 varargs = Unicode()
142 142 varkw = Unicode()
143 143 defaults = List()
144 144
145 145
146 146 class Status(Reference):
147 147 execution_state = Enum((u'busy', u'idle', u'starting'))
148 148
149 149
150 150 class CompleteReply(Reference):
151 151 matches = List(Unicode)
152 152
153 153
154 154 class KernelInfoReply(Reference):
155 155 protocol_version = Version('5.0')
156 156 ipython_version = Version('2.0')
157 157 language_version = Version('2.7')
158 158 language = Unicode()
159 159
160 160
161 161 # IOPub messages
162 162
163 class PyIn(Reference):
163 class ExecuteInput(Reference):
164 164 code = Unicode()
165 165 execution_count = Integer()
166 166
167 167
168 168 PyErr = ExecuteReplyError
169 169
170 170
171 171 class Stream(Reference):
172 172 name = Enum((u'stdout', u'stderr'))
173 173 data = Unicode()
174 174
175 175
176 176 mime_pat = re.compile(r'\w+/\w+')
177 177
178 178 class DisplayData(Reference):
179 179 source = Unicode()
180 180 metadata = Dict()
181 181 data = Dict()
182 182 def _data_changed(self, name, old, new):
183 183 for k,v in iteritems(new):
184 184 assert mime_pat.match(k)
185 185 nt.assert_is_instance(v, string_types)
186 186
187 187
188 188 class PyOut(Reference):
189 189 execution_count = Integer()
190 190 data = Dict()
191 191 def _data_changed(self, name, old, new):
192 192 for k,v in iteritems(new):
193 193 assert mime_pat.match(k)
194 194 nt.assert_is_instance(v, string_types)
195 195
196 196
197 197 references = {
198 198 'execute_reply' : ExecuteReply(),
199 199 'object_info_reply' : OInfoReply(),
200 200 'status' : Status(),
201 201 'complete_reply' : CompleteReply(),
202 202 'kernel_info_reply': KernelInfoReply(),
203 'pyin' : PyIn(),
203 'execute_input' : ExecuteInput(),
204 204 'pyout' : PyOut(),
205 205 'pyerr' : PyErr(),
206 206 'stream' : Stream(),
207 207 'display_data' : DisplayData(),
208 208 'header' : RHeader(),
209 209 }
210 210 """
211 211 Specifications of `content` part of the reply messages.
212 212 """
213 213
214 214
215 215 def validate_message(msg, msg_type=None, parent=None):
216 216 """validate a message
217 217
218 218 This is a generator, and must be iterated through to actually
219 219 trigger each test.
220 220
221 221 If msg_type and/or parent are given, the msg_type and/or parent msg_id
222 222 are compared with the given values.
223 223 """
224 224 RMessage().check(msg)
225 225 if msg_type:
226 226 nt.assert_equal(msg['msg_type'], msg_type)
227 227 if parent:
228 228 nt.assert_equal(msg['parent_header']['msg_id'], parent)
229 229 content = msg['content']
230 230 ref = references[msg['msg_type']]
231 231 ref.check(content)
232 232
233 233
234 234 #-----------------------------------------------------------------------------
235 235 # Tests
236 236 #-----------------------------------------------------------------------------
237 237
238 238 # Shell channel
239 239
240 240 def test_execute():
241 241 flush_channels()
242 242
243 243 msg_id = KC.execute(code='x=1')
244 244 reply = KC.get_shell_msg(timeout=TIMEOUT)
245 245 validate_message(reply, 'execute_reply', msg_id)
246 246
247 247
248 248 def test_execute_silent():
249 249 flush_channels()
250 250 msg_id, reply = execute(code='x=1', silent=True)
251 251
252 252 # flush status=idle
253 253 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
254 254 validate_message(status, 'status', msg_id)
255 255 nt.assert_equal(status['content']['execution_state'], 'idle')
256 256
257 257 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
258 258 count = reply['execution_count']
259 259
260 260 msg_id, reply = execute(code='x=2', silent=True)
261 261
262 262 # flush status=idle
263 263 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
264 264 validate_message(status, 'status', msg_id)
265 265 nt.assert_equal(status['content']['execution_state'], 'idle')
266 266
267 267 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
268 268 count_2 = reply['execution_count']
269 269 nt.assert_equal(count_2, count)
270 270
271 271
272 272 def test_execute_error():
273 273 flush_channels()
274 274
275 275 msg_id, reply = execute(code='1/0')
276 276 nt.assert_equal(reply['status'], 'error')
277 277 nt.assert_equal(reply['ename'], 'ZeroDivisionError')
278 278
279 279 pyerr = KC.iopub_channel.get_msg(timeout=TIMEOUT)
280 280 validate_message(pyerr, 'pyerr', msg_id)
281 281
282 282
283 283 def test_execute_inc():
284 284 """execute request should increment execution_count"""
285 285 flush_channels()
286 286
287 287 msg_id, reply = execute(code='x=1')
288 288 count = reply['execution_count']
289 289
290 290 flush_channels()
291 291
292 292 msg_id, reply = execute(code='x=2')
293 293 count_2 = reply['execution_count']
294 294 nt.assert_equal(count_2, count+1)
295 295
296 296
297 297 def test_user_variables():
298 298 flush_channels()
299 299
300 300 msg_id, reply = execute(code='x=1', user_variables=['x'])
301 301 user_variables = reply['user_variables']
302 302 nt.assert_equal(user_variables, {u'x': {
303 303 u'status': u'ok',
304 304 u'data': {u'text/plain': u'1'},
305 305 u'metadata': {},
306 306 }})
307 307
308 308
309 309 def test_user_variables_fail():
310 310 flush_channels()
311 311
312 312 msg_id, reply = execute(code='x=1', user_variables=['nosuchname'])
313 313 user_variables = reply['user_variables']
314 314 foo = user_variables['nosuchname']
315 315 nt.assert_equal(foo['status'], 'error')
316 316 nt.assert_equal(foo['ename'], 'KeyError')
317 317
318 318
319 319 def test_user_expressions():
320 320 flush_channels()
321 321
322 322 msg_id, reply = execute(code='x=1', user_expressions=dict(foo='x+1'))
323 323 user_expressions = reply['user_expressions']
324 324 nt.assert_equal(user_expressions, {u'foo': {
325 325 u'status': u'ok',
326 326 u'data': {u'text/plain': u'2'},
327 327 u'metadata': {},
328 328 }})
329 329
330 330
331 331 def test_user_expressions_fail():
332 332 flush_channels()
333 333
334 334 msg_id, reply = execute(code='x=0', user_expressions=dict(foo='nosuchname'))
335 335 user_expressions = reply['user_expressions']
336 336 foo = user_expressions['foo']
337 337 nt.assert_equal(foo['status'], 'error')
338 338 nt.assert_equal(foo['ename'], 'NameError')
339 339
340 340
341 341 def test_oinfo():
342 342 flush_channels()
343 343
344 344 msg_id = KC.object_info('a')
345 345 reply = KC.get_shell_msg(timeout=TIMEOUT)
346 346 validate_message(reply, 'object_info_reply', msg_id)
347 347
348 348
349 349 def test_oinfo_found():
350 350 flush_channels()
351 351
352 352 msg_id, reply = execute(code='a=5')
353 353
354 354 msg_id = KC.object_info('a')
355 355 reply = KC.get_shell_msg(timeout=TIMEOUT)
356 356 validate_message(reply, 'object_info_reply', msg_id)
357 357 content = reply['content']
358 358 assert content['found']
359 359 argspec = content['argspec']
360 360 nt.assert_is(argspec, None)
361 361
362 362
363 363 def test_oinfo_detail():
364 364 flush_channels()
365 365
366 366 msg_id, reply = execute(code='ip=get_ipython()')
367 367
368 368 msg_id = KC.object_info('ip.object_inspect', detail_level=2)
369 369 reply = KC.get_shell_msg(timeout=TIMEOUT)
370 370 validate_message(reply, 'object_info_reply', msg_id)
371 371 content = reply['content']
372 372 assert content['found']
373 373 argspec = content['argspec']
374 374 nt.assert_is_instance(argspec, dict, "expected non-empty argspec dict, got %r" % argspec)
375 375 nt.assert_equal(argspec['defaults'], [0])
376 376
377 377
378 378 def test_oinfo_not_found():
379 379 flush_channels()
380 380
381 381 msg_id = KC.object_info('dne')
382 382 reply = KC.get_shell_msg(timeout=TIMEOUT)
383 383 validate_message(reply, 'object_info_reply', msg_id)
384 384 content = reply['content']
385 385 nt.assert_false(content['found'])
386 386
387 387
388 388 def test_complete():
389 389 flush_channels()
390 390
391 391 msg_id, reply = execute(code="alpha = albert = 5")
392 392
393 393 msg_id = KC.complete('al', 'al', 2)
394 394 reply = KC.get_shell_msg(timeout=TIMEOUT)
395 395 validate_message(reply, 'complete_reply', msg_id)
396 396 matches = reply['content']['matches']
397 397 for name in ('alpha', 'albert'):
398 398 nt.assert_in(name, matches)
399 399
400 400
401 401 def test_kernel_info_request():
402 402 flush_channels()
403 403
404 404 msg_id = KC.kernel_info()
405 405 reply = KC.get_shell_msg(timeout=TIMEOUT)
406 406 validate_message(reply, 'kernel_info_reply', msg_id)
407 407
408 408
409 409 def test_single_payload():
410 410 flush_channels()
411 411 msg_id, reply = execute(code="for i in range(3):\n"+
412 412 " x=range?\n")
413 413 payload = reply['payload']
414 414 next_input_pls = [pl for pl in payload if pl["source"] == "set_next_input"]
415 415 nt.assert_equal(len(next_input_pls), 1)
416 416
417 417
418 418 # IOPub channel
419 419
420 420
421 421 def test_stream():
422 422 flush_channels()
423 423
424 424 msg_id, reply = execute("print('hi')")
425 425
426 426 stdout = KC.iopub_channel.get_msg(timeout=TIMEOUT)
427 427 validate_message(stdout, 'stream', msg_id)
428 428 content = stdout['content']
429 429 nt.assert_equal(content['name'], u'stdout')
430 430 nt.assert_equal(content['data'], u'hi\n')
431 431
432 432
433 433 def test_display_data():
434 434 flush_channels()
435 435
436 436 msg_id, reply = execute("from IPython.core.display import display; display(1)")
437 437
438 438 display = KC.iopub_channel.get_msg(timeout=TIMEOUT)
439 439 validate_message(display, 'display_data', parent=msg_id)
440 440 data = display['content']['data']
441 441 nt.assert_equal(data['text/plain'], u'1')
442 442
@@ -1,179 +1,171 b''
1 1 """utilities for testing IPython kernels"""
2 2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2013 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
13 5
14 6 import atexit
15 7
16 8 from contextlib import contextmanager
17 9 from subprocess import PIPE, STDOUT
18 10 try:
19 11 from queue import Empty # Py 3
20 12 except ImportError:
21 13 from Queue import Empty # Py 2
22 14
23 15 import nose
24 16 import nose.tools as nt
25 17
26 18 from IPython.kernel import KernelManager
27 19
28 20 #-------------------------------------------------------------------------------
29 21 # Globals
30 22 #-------------------------------------------------------------------------------
31 23
32 24 STARTUP_TIMEOUT = 60
33 25 TIMEOUT = 15
34 26
35 27 KM = None
36 28 KC = None
37 29
38 30 #-------------------------------------------------------------------------------
39 31 # code
40 32 #-------------------------------------------------------------------------------
41 33
42 34
43 35 def start_new_kernel(argv=None):
44 36 """start a new kernel, and return its Manager and Client"""
45 37 km = KernelManager()
46 38 kwargs = dict(stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT)
47 39 if argv:
48 40 kwargs['extra_arguments'] = argv
49 41 km.start_kernel(**kwargs)
50 42 kc = km.client()
51 43 kc.start_channels()
52 44
53 45 msg_id = kc.kernel_info()
54 46 kc.get_shell_msg(block=True, timeout=STARTUP_TIMEOUT)
55 47 flush_channels(kc)
56 48 return km, kc
57 49
58 50 def flush_channels(kc=None):
59 51 """flush any messages waiting on the queue"""
60 52 from .test_message_spec import validate_message
61 53
62 54 if kc is None:
63 55 kc = KC
64 56 for channel in (kc.shell_channel, kc.iopub_channel):
65 57 while True:
66 58 try:
67 59 msg = channel.get_msg(block=True, timeout=0.1)
68 60 except Empty:
69 61 break
70 62 else:
71 63 validate_message(msg)
72 64
73 65
74 66 def execute(code='', kc=None, **kwargs):
75 67 """wrapper for doing common steps for validating an execution request"""
76 68 from .test_message_spec import validate_message
77 69 if kc is None:
78 70 kc = KC
79 71 msg_id = kc.execute(code=code, **kwargs)
80 72 reply = kc.get_shell_msg(timeout=TIMEOUT)
81 73 validate_message(reply, 'execute_reply', msg_id)
82 74 busy = kc.get_iopub_msg(timeout=TIMEOUT)
83 75 validate_message(busy, 'status', msg_id)
84 76 nt.assert_equal(busy['content']['execution_state'], 'busy')
85 77
86 78 if not kwargs.get('silent'):
87 pyin = kc.get_iopub_msg(timeout=TIMEOUT)
88 validate_message(pyin, 'pyin', msg_id)
89 nt.assert_equal(pyin['content']['code'], code)
79 execute_input = kc.get_iopub_msg(timeout=TIMEOUT)
80 validate_message(execute_input, 'execute_input', msg_id)
81 nt.assert_equal(execute_input['content']['code'], code)
90 82
91 83 return msg_id, reply['content']
92 84
93 85 def start_global_kernel():
94 86 """start the global kernel (if it isn't running) and return its client"""
95 87 global KM, KC
96 88 if KM is None:
97 89 KM, KC = start_new_kernel()
98 90 atexit.register(stop_global_kernel)
99 91 return KC
100 92
101 93 @contextmanager
102 94 def kernel():
103 95 """Context manager for the global kernel instance
104 96
105 97 Should be used for most kernel tests
106 98
107 99 Returns
108 100 -------
109 101 kernel_client: connected KernelClient instance
110 102 """
111 103 yield start_global_kernel()
112 104
113 105 def uses_kernel(test_f):
114 106 """Decorator for tests that use the global kernel"""
115 107 def wrapped_test():
116 108 with kernel() as kc:
117 109 test_f(kc)
118 110 wrapped_test.__doc__ = test_f.__doc__
119 111 wrapped_test.__name__ = test_f.__name__
120 112 return wrapped_test
121 113
122 114 def stop_global_kernel():
123 115 """Stop the global shared kernel instance, if it exists"""
124 116 global KM, KC
125 117 KC.stop_channels()
126 118 KC = None
127 119 if KM is None:
128 120 return
129 121 KM.shutdown_kernel(now=True)
130 122 KM = None
131 123
132 124 @contextmanager
133 125 def new_kernel(argv=None):
134 126 """Context manager for a new kernel in a subprocess
135 127
136 128 Should only be used for tests where the kernel must not be re-used.
137 129
138 130 Returns
139 131 -------
140 132 kernel_client: connected KernelClient instance
141 133 """
142 134 km, kc = start_new_kernel(argv)
143 135 try:
144 136 yield kc
145 137 finally:
146 138 kc.stop_channels()
147 139 km.shutdown_kernel(now=True)
148 140
149 141
150 142 def assemble_output(iopub):
151 143 """assemble stdout/err from an execution"""
152 144 stdout = ''
153 145 stderr = ''
154 146 while True:
155 147 msg = iopub.get_msg(block=True, timeout=1)
156 148 msg_type = msg['msg_type']
157 149 content = msg['content']
158 150 if msg_type == 'status' and content['execution_state'] == 'idle':
159 151 # idle message signals end of output
160 152 break
161 153 elif msg['msg_type'] == 'stream':
162 154 if content['name'] == 'stdout':
163 155 stdout += content['data']
164 156 elif content['name'] == 'stderr':
165 157 stderr += content['data']
166 158 else:
167 159 raise KeyError("bad stream: %r" % content['name'])
168 160 else:
169 161 # other output, ignored
170 162 pass
171 163 return stdout, stderr
172 164
173 165 def wait_for_idle(kc):
174 166 while True:
175 167 msg = kc.iopub_channel.get_msg(block=True, timeout=1)
176 168 msg_type = msg['msg_type']
177 169 content = msg['content']
178 170 if msg_type == 'status' and content['execution_state'] == 'idle':
179 171 break
@@ -1,797 +1,797 b''
1 1 #!/usr/bin/env python
2 2 """An interactive kernel that talks to frontends over 0MQ."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 from __future__ import print_function
8 8
9 9 import sys
10 10 import time
11 11 import traceback
12 12 import logging
13 13 import uuid
14 14
15 15 from datetime import datetime
16 16 from signal import (
17 17 signal, default_int_handler, SIGINT
18 18 )
19 19
20 20 import zmq
21 21 from zmq.eventloop import ioloop
22 22 from zmq.eventloop.zmqstream import ZMQStream
23 23
24 24 from IPython.config.configurable import Configurable
25 25 from IPython.core.error import StdinNotImplementedError
26 26 from IPython.core import release
27 27 from IPython.utils import py3compat
28 28 from IPython.utils.py3compat import builtin_mod, unicode_type, string_types
29 29 from IPython.utils.jsonutil import json_clean
30 30 from IPython.utils.traitlets import (
31 31 Any, Instance, Float, Dict, List, Set, Integer, Unicode,
32 32 Type, Bool,
33 33 )
34 34
35 35 from .serialize import serialize_object, unpack_apply_message
36 36 from .session import Session
37 37 from .zmqshell import ZMQInteractiveShell
38 38
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Main kernel class
42 42 #-----------------------------------------------------------------------------
43 43
44 44 protocol_version = release.kernel_protocol_version
45 45 ipython_version = release.version
46 46 language_version = sys.version.split()[0]
47 47
48 48
49 49 class Kernel(Configurable):
50 50
51 51 #---------------------------------------------------------------------------
52 52 # Kernel interface
53 53 #---------------------------------------------------------------------------
54 54
55 55 # attribute to override with a GUI
56 56 eventloop = Any(None)
57 57 def _eventloop_changed(self, name, old, new):
58 58 """schedule call to eventloop from IOLoop"""
59 59 loop = ioloop.IOLoop.instance()
60 60 loop.add_callback(self.enter_eventloop)
61 61
62 62 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
63 63 shell_class = Type(ZMQInteractiveShell)
64 64
65 65 session = Instance(Session)
66 66 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
67 67 shell_streams = List()
68 68 control_stream = Instance(ZMQStream)
69 69 iopub_socket = Instance(zmq.Socket)
70 70 stdin_socket = Instance(zmq.Socket)
71 71 log = Instance(logging.Logger)
72 72
73 73 user_module = Any()
74 74 def _user_module_changed(self, name, old, new):
75 75 if self.shell is not None:
76 76 self.shell.user_module = new
77 77
78 78 user_ns = Instance(dict, args=None, allow_none=True)
79 79 def _user_ns_changed(self, name, old, new):
80 80 if self.shell is not None:
81 81 self.shell.user_ns = new
82 82 self.shell.init_user_ns()
83 83
84 84 # identities:
85 85 int_id = Integer(-1)
86 86 ident = Unicode()
87 87
88 88 def _ident_default(self):
89 89 return unicode_type(uuid.uuid4())
90 90
91 91 # Private interface
92 92
93 93 _darwin_app_nap = Bool(True, config=True,
94 94 help="""Whether to use appnope for compatiblity with OS X App Nap.
95 95
96 96 Only affects OS X >= 10.9.
97 97 """
98 98 )
99 99
100 100 # Time to sleep after flushing the stdout/err buffers in each execute
101 101 # cycle. While this introduces a hard limit on the minimal latency of the
102 102 # execute cycle, it helps prevent output synchronization problems for
103 103 # clients.
104 104 # Units are in seconds. The minimum zmq latency on local host is probably
105 105 # ~150 microseconds, set this to 500us for now. We may need to increase it
106 106 # a little if it's not enough after more interactive testing.
107 107 _execute_sleep = Float(0.0005, config=True)
108 108
109 109 # Frequency of the kernel's event loop.
110 110 # Units are in seconds, kernel subclasses for GUI toolkits may need to
111 111 # adapt to milliseconds.
112 112 _poll_interval = Float(0.05, config=True)
113 113
114 114 # If the shutdown was requested over the network, we leave here the
115 115 # necessary reply message so it can be sent by our registered atexit
116 116 # handler. This ensures that the reply is only sent to clients truly at
117 117 # the end of our shutdown process (which happens after the underlying
118 118 # IPython shell's own shutdown).
119 119 _shutdown_message = None
120 120
121 121 # This is a dict of port number that the kernel is listening on. It is set
122 122 # by record_ports and used by connect_request.
123 123 _recorded_ports = Dict()
124 124
125 125 # A reference to the Python builtin 'raw_input' function.
126 126 # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
127 127 _sys_raw_input = Any()
128 128 _sys_eval_input = Any()
129 129
130 130 # set of aborted msg_ids
131 131 aborted = Set()
132 132
133 133
134 134 def __init__(self, **kwargs):
135 135 super(Kernel, self).__init__(**kwargs)
136 136
137 137 # Initialize the InteractiveShell subclass
138 138 self.shell = self.shell_class.instance(parent=self,
139 139 profile_dir = self.profile_dir,
140 140 user_module = self.user_module,
141 141 user_ns = self.user_ns,
142 142 kernel = self,
143 143 )
144 144 self.shell.displayhook.session = self.session
145 145 self.shell.displayhook.pub_socket = self.iopub_socket
146 146 self.shell.displayhook.topic = self._topic('pyout')
147 147 self.shell.display_pub.session = self.session
148 148 self.shell.display_pub.pub_socket = self.iopub_socket
149 149 self.shell.data_pub.session = self.session
150 150 self.shell.data_pub.pub_socket = self.iopub_socket
151 151
152 152 # TMP - hack while developing
153 153 self.shell._reply_content = None
154 154
155 155 # Build dict of handlers for message types
156 156 msg_types = [ 'execute_request', 'complete_request',
157 157 'object_info_request', 'history_request',
158 158 'kernel_info_request',
159 159 'connect_request', 'shutdown_request',
160 160 'apply_request',
161 161 ]
162 162 self.shell_handlers = {}
163 163 for msg_type in msg_types:
164 164 self.shell_handlers[msg_type] = getattr(self, msg_type)
165 165
166 166 comm_msg_types = [ 'comm_open', 'comm_msg', 'comm_close' ]
167 167 comm_manager = self.shell.comm_manager
168 168 for msg_type in comm_msg_types:
169 169 self.shell_handlers[msg_type] = getattr(comm_manager, msg_type)
170 170
171 171 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
172 172 self.control_handlers = {}
173 173 for msg_type in control_msg_types:
174 174 self.control_handlers[msg_type] = getattr(self, msg_type)
175 175
176 176
177 177 def dispatch_control(self, msg):
178 178 """dispatch control requests"""
179 179 idents,msg = self.session.feed_identities(msg, copy=False)
180 180 try:
181 181 msg = self.session.unserialize(msg, content=True, copy=False)
182 182 except:
183 183 self.log.error("Invalid Control Message", exc_info=True)
184 184 return
185 185
186 186 self.log.debug("Control received: %s", msg)
187 187
188 188 header = msg['header']
189 189 msg_id = header['msg_id']
190 190 msg_type = header['msg_type']
191 191
192 192 handler = self.control_handlers.get(msg_type, None)
193 193 if handler is None:
194 194 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
195 195 else:
196 196 try:
197 197 handler(self.control_stream, idents, msg)
198 198 except Exception:
199 199 self.log.error("Exception in control handler:", exc_info=True)
200 200
201 201 def dispatch_shell(self, stream, msg):
202 202 """dispatch shell requests"""
203 203 # flush control requests first
204 204 if self.control_stream:
205 205 self.control_stream.flush()
206 206
207 207 idents,msg = self.session.feed_identities(msg, copy=False)
208 208 try:
209 209 msg = self.session.unserialize(msg, content=True, copy=False)
210 210 except:
211 211 self.log.error("Invalid Message", exc_info=True)
212 212 return
213 213
214 214 header = msg['header']
215 215 msg_id = header['msg_id']
216 216 msg_type = msg['header']['msg_type']
217 217
218 218 # Print some info about this message and leave a '--->' marker, so it's
219 219 # easier to trace visually the message chain when debugging. Each
220 220 # handler prints its message at the end.
221 221 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
222 222 self.log.debug(' Content: %s\n --->\n ', msg['content'])
223 223
224 224 if msg_id in self.aborted:
225 225 self.aborted.remove(msg_id)
226 226 # is it safe to assume a msg_id will not be resubmitted?
227 227 reply_type = msg_type.split('_')[0] + '_reply'
228 228 status = {'status' : 'aborted'}
229 229 md = {'engine' : self.ident}
230 230 md.update(status)
231 231 reply_msg = self.session.send(stream, reply_type, metadata=md,
232 232 content=status, parent=msg, ident=idents)
233 233 return
234 234
235 235 handler = self.shell_handlers.get(msg_type, None)
236 236 if handler is None:
237 237 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
238 238 else:
239 239 # ensure default_int_handler during handler call
240 240 sig = signal(SIGINT, default_int_handler)
241 241 try:
242 242 handler(stream, idents, msg)
243 243 except Exception:
244 244 self.log.error("Exception in message handler:", exc_info=True)
245 245 finally:
246 246 signal(SIGINT, sig)
247 247
248 248 def enter_eventloop(self):
249 249 """enter eventloop"""
250 250 self.log.info("entering eventloop %s", self.eventloop)
251 251 for stream in self.shell_streams:
252 252 # flush any pending replies,
253 253 # which may be skipped by entering the eventloop
254 254 stream.flush(zmq.POLLOUT)
255 255 # restore default_int_handler
256 256 signal(SIGINT, default_int_handler)
257 257 while self.eventloop is not None:
258 258 try:
259 259 self.eventloop(self)
260 260 except KeyboardInterrupt:
261 261 # Ctrl-C shouldn't crash the kernel
262 262 self.log.error("KeyboardInterrupt caught in kernel")
263 263 continue
264 264 else:
265 265 # eventloop exited cleanly, this means we should stop (right?)
266 266 self.eventloop = None
267 267 break
268 268 self.log.info("exiting eventloop")
269 269
270 270 def start(self):
271 271 """register dispatchers for streams"""
272 272 self.shell.exit_now = False
273 273 if self.control_stream:
274 274 self.control_stream.on_recv(self.dispatch_control, copy=False)
275 275
276 276 def make_dispatcher(stream):
277 277 def dispatcher(msg):
278 278 return self.dispatch_shell(stream, msg)
279 279 return dispatcher
280 280
281 281 for s in self.shell_streams:
282 282 s.on_recv(make_dispatcher(s), copy=False)
283 283
284 284 # publish idle status
285 285 self._publish_status('starting')
286 286
287 287 def do_one_iteration(self):
288 288 """step eventloop just once"""
289 289 if self.control_stream:
290 290 self.control_stream.flush()
291 291 for stream in self.shell_streams:
292 292 # handle at most one request per iteration
293 293 stream.flush(zmq.POLLIN, 1)
294 294 stream.flush(zmq.POLLOUT)
295 295
296 296
297 297 def record_ports(self, ports):
298 298 """Record the ports that this kernel is using.
299 299
300 300 The creator of the Kernel instance must call this methods if they
301 301 want the :meth:`connect_request` method to return the port numbers.
302 302 """
303 303 self._recorded_ports = ports
304 304
305 305 #---------------------------------------------------------------------------
306 306 # Kernel request handlers
307 307 #---------------------------------------------------------------------------
308 308
309 309 def _make_metadata(self, other=None):
310 310 """init metadata dict, for execute/apply_reply"""
311 311 new_md = {
312 312 'dependencies_met' : True,
313 313 'engine' : self.ident,
314 314 'started': datetime.now(),
315 315 }
316 316 if other:
317 317 new_md.update(other)
318 318 return new_md
319 319
320 def _publish_pyin(self, code, parent, execution_count):
321 """Publish the code request on the pyin stream."""
320 def _publish_execute_input(self, code, parent, execution_count):
321 """Publish the code request on the iopub stream."""
322 322
323 self.session.send(self.iopub_socket, u'pyin',
323 self.session.send(self.iopub_socket, u'execute_input',
324 324 {u'code':code, u'execution_count': execution_count},
325 parent=parent, ident=self._topic('pyin')
325 parent=parent, ident=self._topic('execute_input')
326 326 )
327 327
328 328 def _publish_status(self, status, parent=None):
329 329 """send status (busy/idle) on IOPub"""
330 330 self.session.send(self.iopub_socket,
331 331 u'status',
332 332 {u'execution_state': status},
333 333 parent=parent,
334 334 ident=self._topic('status'),
335 335 )
336 336
337 337
338 338 def execute_request(self, stream, ident, parent):
339 339 """handle an execute_request"""
340 340
341 341 self._publish_status(u'busy', parent)
342 342
343 343 try:
344 344 content = parent[u'content']
345 345 code = py3compat.cast_unicode_py2(content[u'code'])
346 346 silent = content[u'silent']
347 347 store_history = content.get(u'store_history', not silent)
348 348 except:
349 349 self.log.error("Got bad msg: ")
350 350 self.log.error("%s", parent)
351 351 return
352 352
353 353 md = self._make_metadata(parent['metadata'])
354 354
355 355 shell = self.shell # we'll need this a lot here
356 356
357 357 # Replace raw_input. Note that is not sufficient to replace
358 358 # raw_input in the user namespace.
359 359 if content.get('allow_stdin', False):
360 360 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
361 361 input = lambda prompt='': eval(raw_input(prompt))
362 362 else:
363 363 raw_input = input = lambda prompt='' : self._no_raw_input()
364 364
365 365 if py3compat.PY3:
366 366 self._sys_raw_input = builtin_mod.input
367 367 builtin_mod.input = raw_input
368 368 else:
369 369 self._sys_raw_input = builtin_mod.raw_input
370 370 self._sys_eval_input = builtin_mod.input
371 371 builtin_mod.raw_input = raw_input
372 372 builtin_mod.input = input
373 373
374 374 # Set the parent message of the display hook and out streams.
375 375 shell.set_parent(parent)
376 376
377 377 # Re-broadcast our input for the benefit of listening clients, and
378 378 # start computing output
379 379 if not silent:
380 self._publish_pyin(code, parent, shell.execution_count)
380 self._publish_execute_input(code, parent, shell.execution_count)
381 381
382 382 reply_content = {}
383 383 # FIXME: the shell calls the exception handler itself.
384 384 shell._reply_content = None
385 385 try:
386 386 shell.run_cell(code, store_history=store_history, silent=silent)
387 387 except:
388 388 status = u'error'
389 389 # FIXME: this code right now isn't being used yet by default,
390 390 # because the run_cell() call above directly fires off exception
391 391 # reporting. This code, therefore, is only active in the scenario
392 392 # where runlines itself has an unhandled exception. We need to
393 393 # uniformize this, for all exception construction to come from a
394 394 # single location in the codbase.
395 395 etype, evalue, tb = sys.exc_info()
396 396 tb_list = traceback.format_exception(etype, evalue, tb)
397 397 reply_content.update(shell._showtraceback(etype, evalue, tb_list))
398 398 else:
399 399 status = u'ok'
400 400 finally:
401 401 # Restore raw_input.
402 402 if py3compat.PY3:
403 403 builtin_mod.input = self._sys_raw_input
404 404 else:
405 405 builtin_mod.raw_input = self._sys_raw_input
406 406 builtin_mod.input = self._sys_eval_input
407 407
408 408 reply_content[u'status'] = status
409 409
410 410 # Return the execution counter so clients can display prompts
411 411 reply_content['execution_count'] = shell.execution_count - 1
412 412
413 413 # FIXME - fish exception info out of shell, possibly left there by
414 414 # runlines. We'll need to clean up this logic later.
415 415 if shell._reply_content is not None:
416 416 reply_content.update(shell._reply_content)
417 417 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='execute')
418 418 reply_content['engine_info'] = e_info
419 419 # reset after use
420 420 shell._reply_content = None
421 421
422 422 if 'traceback' in reply_content:
423 423 self.log.info("Exception in execute request:\n%s", '\n'.join(reply_content['traceback']))
424 424
425 425
426 426 # At this point, we can tell whether the main code execution succeeded
427 427 # or not. If it did, we proceed to evaluate user_variables/expressions
428 428 if reply_content['status'] == 'ok':
429 429 reply_content[u'user_variables'] = \
430 430 shell.user_variables(content.get(u'user_variables', []))
431 431 reply_content[u'user_expressions'] = \
432 432 shell.user_expressions(content.get(u'user_expressions', {}))
433 433 else:
434 434 # If there was an error, don't even try to compute variables or
435 435 # expressions
436 436 reply_content[u'user_variables'] = {}
437 437 reply_content[u'user_expressions'] = {}
438 438
439 439 # Payloads should be retrieved regardless of outcome, so we can both
440 440 # recover partial output (that could have been generated early in a
441 441 # block, before an error) and clear the payload system always.
442 442 reply_content[u'payload'] = shell.payload_manager.read_payload()
443 443 # Be agressive about clearing the payload because we don't want
444 444 # it to sit in memory until the next execute_request comes in.
445 445 shell.payload_manager.clear_payload()
446 446
447 447 # Flush output before sending the reply.
448 448 sys.stdout.flush()
449 449 sys.stderr.flush()
450 450 # FIXME: on rare occasions, the flush doesn't seem to make it to the
451 451 # clients... This seems to mitigate the problem, but we definitely need
452 452 # to better understand what's going on.
453 453 if self._execute_sleep:
454 454 time.sleep(self._execute_sleep)
455 455
456 456 # Send the reply.
457 457 reply_content = json_clean(reply_content)
458 458
459 459 md['status'] = reply_content['status']
460 460 if reply_content['status'] == 'error' and \
461 461 reply_content['ename'] == 'UnmetDependency':
462 462 md['dependencies_met'] = False
463 463
464 464 reply_msg = self.session.send(stream, u'execute_reply',
465 465 reply_content, parent, metadata=md,
466 466 ident=ident)
467 467
468 468 self.log.debug("%s", reply_msg)
469 469
470 470 if not silent and reply_msg['content']['status'] == u'error':
471 471 self._abort_queues()
472 472
473 473 self._publish_status(u'idle', parent)
474 474
475 475 def complete_request(self, stream, ident, parent):
476 476 txt, matches = self._complete(parent)
477 477 matches = {'matches' : matches,
478 478 'matched_text' : txt,
479 479 'status' : 'ok'}
480 480 matches = json_clean(matches)
481 481 completion_msg = self.session.send(stream, 'complete_reply',
482 482 matches, parent, ident)
483 483 self.log.debug("%s", completion_msg)
484 484
485 485 def object_info_request(self, stream, ident, parent):
486 486 content = parent['content']
487 487 object_info = self.shell.object_inspect(content['oname'],
488 488 detail_level = content.get('detail_level', 0)
489 489 )
490 490 # Before we send this object over, we scrub it for JSON usage
491 491 oinfo = json_clean(object_info)
492 492 msg = self.session.send(stream, 'object_info_reply',
493 493 oinfo, parent, ident)
494 494 self.log.debug("%s", msg)
495 495
496 496 def history_request(self, stream, ident, parent):
497 497 # We need to pull these out, as passing **kwargs doesn't work with
498 498 # unicode keys before Python 2.6.5.
499 499 hist_access_type = parent['content']['hist_access_type']
500 500 raw = parent['content']['raw']
501 501 output = parent['content']['output']
502 502 if hist_access_type == 'tail':
503 503 n = parent['content']['n']
504 504 hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
505 505 include_latest=True)
506 506
507 507 elif hist_access_type == 'range':
508 508 session = parent['content']['session']
509 509 start = parent['content']['start']
510 510 stop = parent['content']['stop']
511 511 hist = self.shell.history_manager.get_range(session, start, stop,
512 512 raw=raw, output=output)
513 513
514 514 elif hist_access_type == 'search':
515 515 n = parent['content'].get('n')
516 516 unique = parent['content'].get('unique', False)
517 517 pattern = parent['content']['pattern']
518 518 hist = self.shell.history_manager.search(
519 519 pattern, raw=raw, output=output, n=n, unique=unique)
520 520
521 521 else:
522 522 hist = []
523 523 hist = list(hist)
524 524 content = {'history' : hist}
525 525 content = json_clean(content)
526 526 msg = self.session.send(stream, 'history_reply',
527 527 content, parent, ident)
528 528 self.log.debug("Sending history reply with %i entries", len(hist))
529 529
530 530 def connect_request(self, stream, ident, parent):
531 531 if self._recorded_ports is not None:
532 532 content = self._recorded_ports.copy()
533 533 else:
534 534 content = {}
535 535 msg = self.session.send(stream, 'connect_reply',
536 536 content, parent, ident)
537 537 self.log.debug("%s", msg)
538 538
539 539 def kernel_info_request(self, stream, ident, parent):
540 540 vinfo = {
541 541 'protocol_version': protocol_version,
542 542 'ipython_version': ipython_version,
543 543 'language_version': language_version,
544 544 'language': 'python',
545 545 }
546 546 msg = self.session.send(stream, 'kernel_info_reply',
547 547 vinfo, parent, ident)
548 548 self.log.debug("%s", msg)
549 549
550 550 def shutdown_request(self, stream, ident, parent):
551 551 self.shell.exit_now = True
552 552 content = dict(status='ok')
553 553 content.update(parent['content'])
554 554 self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
555 555 # same content, but different msg_id for broadcasting on IOPub
556 556 self._shutdown_message = self.session.msg(u'shutdown_reply',
557 557 content, parent
558 558 )
559 559
560 560 self._at_shutdown()
561 561 # call sys.exit after a short delay
562 562 loop = ioloop.IOLoop.instance()
563 563 loop.add_timeout(time.time()+0.1, loop.stop)
564 564
565 565 #---------------------------------------------------------------------------
566 566 # Engine methods
567 567 #---------------------------------------------------------------------------
568 568
569 569 def apply_request(self, stream, ident, parent):
570 570 try:
571 571 content = parent[u'content']
572 572 bufs = parent[u'buffers']
573 573 msg_id = parent['header']['msg_id']
574 574 except:
575 575 self.log.error("Got bad msg: %s", parent, exc_info=True)
576 576 return
577 577
578 578 self._publish_status(u'busy', parent)
579 579
580 580 # Set the parent message of the display hook and out streams.
581 581 shell = self.shell
582 582 shell.set_parent(parent)
583 583
584 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
585 # self.iopub_socket.send(pyin_msg)
586 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
584 # execute_input_msg = self.session.msg(u'execute_input',{u'code':code}, parent=parent)
585 # self.iopub_socket.send(execute_input_msg)
586 # self.session.send(self.iopub_socket, u'execute_input', {u'code':code},parent=parent)
587 587 md = self._make_metadata(parent['metadata'])
588 588 try:
589 589 working = shell.user_ns
590 590
591 591 prefix = "_"+str(msg_id).replace("-","")+"_"
592 592
593 593 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
594 594
595 595 fname = getattr(f, '__name__', 'f')
596 596
597 597 fname = prefix+"f"
598 598 argname = prefix+"args"
599 599 kwargname = prefix+"kwargs"
600 600 resultname = prefix+"result"
601 601
602 602 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
603 603 # print ns
604 604 working.update(ns)
605 605 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
606 606 try:
607 607 exec(code, shell.user_global_ns, shell.user_ns)
608 608 result = working.get(resultname)
609 609 finally:
610 610 for key in ns:
611 611 working.pop(key)
612 612
613 613 result_buf = serialize_object(result,
614 614 buffer_threshold=self.session.buffer_threshold,
615 615 item_threshold=self.session.item_threshold,
616 616 )
617 617
618 618 except:
619 619 # invoke IPython traceback formatting
620 620 shell.showtraceback()
621 621 # FIXME - fish exception info out of shell, possibly left there by
622 622 # run_code. We'll need to clean up this logic later.
623 623 reply_content = {}
624 624 if shell._reply_content is not None:
625 625 reply_content.update(shell._reply_content)
626 626 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
627 627 reply_content['engine_info'] = e_info
628 628 # reset after use
629 629 shell._reply_content = None
630 630
631 631 self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent,
632 632 ident=self._topic('pyerr'))
633 633 self.log.info("Exception in apply request:\n%s", '\n'.join(reply_content['traceback']))
634 634 result_buf = []
635 635
636 636 if reply_content['ename'] == 'UnmetDependency':
637 637 md['dependencies_met'] = False
638 638 else:
639 639 reply_content = {'status' : 'ok'}
640 640
641 641 # put 'ok'/'error' status in header, for scheduler introspection:
642 642 md['status'] = reply_content['status']
643 643
644 644 # flush i/o
645 645 sys.stdout.flush()
646 646 sys.stderr.flush()
647 647
648 648 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
649 649 parent=parent, ident=ident,buffers=result_buf, metadata=md)
650 650
651 651 self._publish_status(u'idle', parent)
652 652
653 653 #---------------------------------------------------------------------------
654 654 # Control messages
655 655 #---------------------------------------------------------------------------
656 656
657 657 def abort_request(self, stream, ident, parent):
658 658 """abort a specifig msg by id"""
659 659 msg_ids = parent['content'].get('msg_ids', None)
660 660 if isinstance(msg_ids, string_types):
661 661 msg_ids = [msg_ids]
662 662 if not msg_ids:
663 663 self.abort_queues()
664 664 for mid in msg_ids:
665 665 self.aborted.add(str(mid))
666 666
667 667 content = dict(status='ok')
668 668 reply_msg = self.session.send(stream, 'abort_reply', content=content,
669 669 parent=parent, ident=ident)
670 670 self.log.debug("%s", reply_msg)
671 671
672 672 def clear_request(self, stream, idents, parent):
673 673 """Clear our namespace."""
674 674 self.shell.reset(False)
675 675 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
676 676 content = dict(status='ok'))
677 677
678 678
679 679 #---------------------------------------------------------------------------
680 680 # Protected interface
681 681 #---------------------------------------------------------------------------
682 682
683 683 def _wrap_exception(self, method=None):
684 684 # import here, because _wrap_exception is only used in parallel,
685 685 # and parallel has higher min pyzmq version
686 686 from IPython.parallel.error import wrap_exception
687 687 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
688 688 content = wrap_exception(e_info)
689 689 return content
690 690
691 691 def _topic(self, topic):
692 692 """prefixed topic for IOPub messages"""
693 693 if self.int_id >= 0:
694 694 base = "engine.%i" % self.int_id
695 695 else:
696 696 base = "kernel.%s" % self.ident
697 697
698 698 return py3compat.cast_bytes("%s.%s" % (base, topic))
699 699
700 700 def _abort_queues(self):
701 701 for stream in self.shell_streams:
702 702 if stream:
703 703 self._abort_queue(stream)
704 704
705 705 def _abort_queue(self, stream):
706 706 poller = zmq.Poller()
707 707 poller.register(stream.socket, zmq.POLLIN)
708 708 while True:
709 709 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
710 710 if msg is None:
711 711 return
712 712
713 713 self.log.info("Aborting:")
714 714 self.log.info("%s", msg)
715 715 msg_type = msg['header']['msg_type']
716 716 reply_type = msg_type.split('_')[0] + '_reply'
717 717
718 718 status = {'status' : 'aborted'}
719 719 md = {'engine' : self.ident}
720 720 md.update(status)
721 721 reply_msg = self.session.send(stream, reply_type, metadata=md,
722 722 content=status, parent=msg, ident=idents)
723 723 self.log.debug("%s", reply_msg)
724 724 # We need to wait a bit for requests to come in. This can probably
725 725 # be set shorter for true asynchronous clients.
726 726 poller.poll(50)
727 727
728 728
729 729 def _no_raw_input(self):
730 730 """Raise StdinNotImplentedError if active frontend doesn't support
731 731 stdin."""
732 732 raise StdinNotImplementedError("raw_input was called, but this "
733 733 "frontend does not support stdin.")
734 734
735 735 def _raw_input(self, prompt, ident, parent):
736 736 # Flush output before making the request.
737 737 sys.stderr.flush()
738 738 sys.stdout.flush()
739 739 # flush the stdin socket, to purge stale replies
740 740 while True:
741 741 try:
742 742 self.stdin_socket.recv_multipart(zmq.NOBLOCK)
743 743 except zmq.ZMQError as e:
744 744 if e.errno == zmq.EAGAIN:
745 745 break
746 746 else:
747 747 raise
748 748
749 749 # Send the input request.
750 750 content = json_clean(dict(prompt=prompt))
751 751 self.session.send(self.stdin_socket, u'input_request', content, parent,
752 752 ident=ident)
753 753
754 754 # Await a response.
755 755 while True:
756 756 try:
757 757 ident, reply = self.session.recv(self.stdin_socket, 0)
758 758 except Exception:
759 759 self.log.warn("Invalid Message:", exc_info=True)
760 760 except KeyboardInterrupt:
761 761 # re-raise KeyboardInterrupt, to truncate traceback
762 762 raise KeyboardInterrupt
763 763 else:
764 764 break
765 765 try:
766 766 value = py3compat.unicode_to_str(reply['content']['value'])
767 767 except:
768 768 self.log.error("Got bad raw_input reply: ")
769 769 self.log.error("%s", parent)
770 770 value = ''
771 771 if value == '\x04':
772 772 # EOF
773 773 raise EOFError
774 774 return value
775 775
776 776 def _complete(self, msg):
777 777 c = msg['content']
778 778 try:
779 779 cpos = int(c['cursor_pos'])
780 780 except:
781 781 # If we don't get something that we can convert to an integer, at
782 782 # least attempt the completion guessing the cursor is at the end of
783 783 # the text, if there's any, and otherwise of the line
784 784 cpos = len(c['text'])
785 785 if cpos==0:
786 786 cpos = len(c['line'])
787 787 return self.shell.complete(c['text'], c['line'], cpos)
788 788
789 789 def _at_shutdown(self):
790 790 """Actions taken at shutdown by the kernel, called by python's atexit.
791 791 """
792 792 # io.rprint("Kernel at_shutdown") # dbg
793 793 if self._shutdown_message is not None:
794 794 self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
795 795 self.log.debug("%s", self._shutdown_message)
796 796 [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
797 797
@@ -1,1875 +1,1863 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for IPython parallel"""
2 2
3 Authors:
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 5
5 * MinRK
6 """
7 6 from __future__ import print_function
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
18 7
19 8 import os
20 9 import json
21 10 import sys
22 11 from threading import Thread, Event
23 12 import time
24 13 import warnings
25 14 from datetime import datetime
26 15 from getpass import getpass
27 16 from pprint import pprint
28 17
29 18 pjoin = os.path.join
30 19
31 20 import zmq
32 # from zmq.eventloop import ioloop, zmqstream
33 21
34 22 from IPython.config.configurable import MultipleInstanceError
35 23 from IPython.core.application import BaseIPythonApplication
36 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
37 25
38 26 from IPython.utils.capture import RichOutput
39 27 from IPython.utils.coloransi import TermColors
40 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
41 29 from IPython.utils.localinterfaces import localhost, is_local_ip
42 30 from IPython.utils.path import get_ipython_dir
43 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
44 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 33 Dict, List, Bool, Set, Any)
46 34 from IPython.external.decorator import decorator
47 35 from IPython.external.ssh import tunnel
48 36
49 37 from IPython.parallel import Reference
50 38 from IPython.parallel import error
51 39 from IPython.parallel import util
52 40
53 41 from IPython.kernel.zmq.session import Session, Message
54 42 from IPython.kernel.zmq import serialize
55 43
56 44 from .asyncresult import AsyncResult, AsyncHubResult
57 45 from .view import DirectView, LoadBalancedView
58 46
59 47 #--------------------------------------------------------------------------
60 48 # Decorators for Client methods
61 49 #--------------------------------------------------------------------------
62 50
63 51 @decorator
64 52 def spin_first(f, self, *args, **kwargs):
65 53 """Call spin() to sync state prior to calling the method."""
66 54 self.spin()
67 55 return f(self, *args, **kwargs)
68 56
69 57
70 58 #--------------------------------------------------------------------------
71 59 # Classes
72 60 #--------------------------------------------------------------------------
73 61
74 62
75 63 class ExecuteReply(RichOutput):
76 64 """wrapper for finished Execute results"""
77 65 def __init__(self, msg_id, content, metadata):
78 66 self.msg_id = msg_id
79 67 self._content = content
80 68 self.execution_count = content['execution_count']
81 69 self.metadata = metadata
82 70
83 71 # RichOutput overrides
84 72
85 73 @property
86 74 def source(self):
87 75 pyout = self.metadata['pyout']
88 76 if pyout:
89 77 return pyout.get('source', '')
90 78
91 79 @property
92 80 def data(self):
93 81 pyout = self.metadata['pyout']
94 82 if pyout:
95 83 return pyout.get('data', {})
96 84
97 85 @property
98 86 def _metadata(self):
99 87 pyout = self.metadata['pyout']
100 88 if pyout:
101 89 return pyout.get('metadata', {})
102 90
103 91 def display(self):
104 92 from IPython.display import publish_display_data
105 93 publish_display_data(self.source, self.data, self.metadata)
106 94
107 95 def _repr_mime_(self, mime):
108 96 if mime not in self.data:
109 97 return
110 98 data = self.data[mime]
111 99 if mime in self._metadata:
112 100 return data, self._metadata[mime]
113 101 else:
114 102 return data
115 103
116 104 def __getitem__(self, key):
117 105 return self.metadata[key]
118 106
119 107 def __getattr__(self, key):
120 108 if key not in self.metadata:
121 109 raise AttributeError(key)
122 110 return self.metadata[key]
123 111
124 112 def __repr__(self):
125 113 pyout = self.metadata['pyout'] or {'data':{}}
126 114 text_out = pyout['data'].get('text/plain', '')
127 115 if len(text_out) > 32:
128 116 text_out = text_out[:29] + '...'
129 117
130 118 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
131 119
132 120 def _repr_pretty_(self, p, cycle):
133 121 pyout = self.metadata['pyout'] or {'data':{}}
134 122 text_out = pyout['data'].get('text/plain', '')
135 123
136 124 if not text_out:
137 125 return
138 126
139 127 try:
140 128 ip = get_ipython()
141 129 except NameError:
142 130 colors = "NoColor"
143 131 else:
144 132 colors = ip.colors
145 133
146 134 if colors == "NoColor":
147 135 out = normal = ""
148 136 else:
149 137 out = TermColors.Red
150 138 normal = TermColors.Normal
151 139
152 140 if '\n' in text_out and not text_out.startswith('\n'):
153 141 # add newline for multiline reprs
154 142 text_out = '\n' + text_out
155 143
156 144 p.text(
157 145 out + u'Out[%i:%i]: ' % (
158 146 self.metadata['engine_id'], self.execution_count
159 147 ) + normal + text_out
160 148 )
161 149
162 150
163 151 class Metadata(dict):
164 152 """Subclass of dict for initializing metadata values.
165 153
166 154 Attribute access works on keys.
167 155
168 156 These objects have a strict set of keys - errors will raise if you try
169 157 to add new keys.
170 158 """
171 159 def __init__(self, *args, **kwargs):
172 160 dict.__init__(self)
173 161 md = {'msg_id' : None,
174 162 'submitted' : None,
175 163 'started' : None,
176 164 'completed' : None,
177 165 'received' : None,
178 166 'engine_uuid' : None,
179 167 'engine_id' : None,
180 168 'follow' : None,
181 169 'after' : None,
182 170 'status' : None,
183 171
184 'pyin' : None,
172 'execute_input' : None,
185 173 'pyout' : None,
186 174 'pyerr' : None,
187 175 'stdout' : '',
188 176 'stderr' : '',
189 177 'outputs' : [],
190 178 'data': {},
191 179 'outputs_ready' : False,
192 180 }
193 181 self.update(md)
194 182 self.update(dict(*args, **kwargs))
195 183
196 184 def __getattr__(self, key):
197 185 """getattr aliased to getitem"""
198 186 if key in self:
199 187 return self[key]
200 188 else:
201 189 raise AttributeError(key)
202 190
203 191 def __setattr__(self, key, value):
204 192 """setattr aliased to setitem, with strict"""
205 193 if key in self:
206 194 self[key] = value
207 195 else:
208 196 raise AttributeError(key)
209 197
210 198 def __setitem__(self, key, value):
211 199 """strict static key enforcement"""
212 200 if key in self:
213 201 dict.__setitem__(self, key, value)
214 202 else:
215 203 raise KeyError(key)
216 204
217 205
218 206 class Client(HasTraits):
219 207 """A semi-synchronous client to the IPython ZMQ cluster
220 208
221 209 Parameters
222 210 ----------
223 211
224 212 url_file : str/unicode; path to ipcontroller-client.json
225 213 This JSON file should contain all the information needed to connect to a cluster,
226 214 and is likely the only argument needed.
227 215 Connection information for the Hub's registration. If a json connector
228 216 file is given, then likely no further configuration is necessary.
229 217 [Default: use profile]
230 218 profile : bytes
231 219 The name of the Cluster profile to be used to find connector information.
232 220 If run from an IPython application, the default profile will be the same
233 221 as the running application, otherwise it will be 'default'.
234 222 cluster_id : str
235 223 String id to added to runtime files, to prevent name collisions when using
236 224 multiple clusters with a single profile simultaneously.
237 225 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
238 226 Since this is text inserted into filenames, typical recommendations apply:
239 227 Simple character strings are ideal, and spaces are not recommended (but
240 228 should generally work)
241 229 context : zmq.Context
242 230 Pass an existing zmq.Context instance, otherwise the client will create its own.
243 231 debug : bool
244 232 flag for lots of message printing for debug purposes
245 233 timeout : int/float
246 234 time (in seconds) to wait for connection replies from the Hub
247 235 [Default: 10]
248 236
249 237 #-------------- session related args ----------------
250 238
251 239 config : Config object
252 240 If specified, this will be relayed to the Session for configuration
253 241 username : str
254 242 set username for the session object
255 243
256 244 #-------------- ssh related args ----------------
257 245 # These are args for configuring the ssh tunnel to be used
258 246 # credentials are used to forward connections over ssh to the Controller
259 247 # Note that the ip given in `addr` needs to be relative to sshserver
260 248 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
261 249 # and set sshserver as the same machine the Controller is on. However,
262 250 # the only requirement is that sshserver is able to see the Controller
263 251 # (i.e. is within the same trusted network).
264 252
265 253 sshserver : str
266 254 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
267 255 If keyfile or password is specified, and this is not, it will default to
268 256 the ip given in addr.
269 257 sshkey : str; path to ssh private key file
270 258 This specifies a key to be used in ssh login, default None.
271 259 Regular default ssh keys will be used without specifying this argument.
272 260 password : str
273 261 Your ssh password to sshserver. Note that if this is left None,
274 262 you will be prompted for it if passwordless key based login is unavailable.
275 263 paramiko : bool
276 264 flag for whether to use paramiko instead of shell ssh for tunneling.
277 265 [default: True on win32, False else]
278 266
279 267
280 268 Attributes
281 269 ----------
282 270
283 271 ids : list of int engine IDs
284 272 requesting the ids attribute always synchronizes
285 273 the registration state. To request ids without synchronization,
286 274 use semi-private _ids attributes.
287 275
288 276 history : list of msg_ids
289 277 a list of msg_ids, keeping track of all the execution
290 278 messages you have submitted in order.
291 279
292 280 outstanding : set of msg_ids
293 281 a set of msg_ids that have been submitted, but whose
294 282 results have not yet been received.
295 283
296 284 results : dict
297 285 a dict of all our results, keyed by msg_id
298 286
299 287 block : bool
300 288 determines default behavior when block not specified
301 289 in execution methods
302 290
303 291 Methods
304 292 -------
305 293
306 294 spin
307 295 flushes incoming results and registration state changes
308 296 control methods spin, and requesting `ids` also ensures up to date
309 297
310 298 wait
311 299 wait on one or more msg_ids
312 300
313 301 execution methods
314 302 apply
315 303 legacy: execute, run
316 304
317 305 data movement
318 306 push, pull, scatter, gather
319 307
320 308 query methods
321 309 queue_status, get_result, purge, result_status
322 310
323 311 control methods
324 312 abort, shutdown
325 313
326 314 """
327 315
328 316
329 317 block = Bool(False)
330 318 outstanding = Set()
331 319 results = Instance('collections.defaultdict', (dict,))
332 320 metadata = Instance('collections.defaultdict', (Metadata,))
333 321 history = List()
334 322 debug = Bool(False)
335 323 _spin_thread = Any()
336 324 _stop_spinning = Any()
337 325
338 326 profile=Unicode()
339 327 def _profile_default(self):
340 328 if BaseIPythonApplication.initialized():
341 329 # an IPython app *might* be running, try to get its profile
342 330 try:
343 331 return BaseIPythonApplication.instance().profile
344 332 except (AttributeError, MultipleInstanceError):
345 333 # could be a *different* subclass of config.Application,
346 334 # which would raise one of these two errors.
347 335 return u'default'
348 336 else:
349 337 return u'default'
350 338
351 339
352 340 _outstanding_dict = Instance('collections.defaultdict', (set,))
353 341 _ids = List()
354 342 _connected=Bool(False)
355 343 _ssh=Bool(False)
356 344 _context = Instance('zmq.Context')
357 345 _config = Dict()
358 346 _engines=Instance(util.ReverseDict, (), {})
359 347 # _hub_socket=Instance('zmq.Socket')
360 348 _query_socket=Instance('zmq.Socket')
361 349 _control_socket=Instance('zmq.Socket')
362 350 _iopub_socket=Instance('zmq.Socket')
363 351 _notification_socket=Instance('zmq.Socket')
364 352 _mux_socket=Instance('zmq.Socket')
365 353 _task_socket=Instance('zmq.Socket')
366 354 _task_scheme=Unicode()
367 355 _closed = False
368 356 _ignored_control_replies=Integer(0)
369 357 _ignored_hub_replies=Integer(0)
370 358
371 359 def __new__(self, *args, **kw):
372 360 # don't raise on positional args
373 361 return HasTraits.__new__(self, **kw)
374 362
375 363 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
376 364 context=None, debug=False,
377 365 sshserver=None, sshkey=None, password=None, paramiko=None,
378 366 timeout=10, cluster_id=None, **extra_args
379 367 ):
380 368 if profile:
381 369 super(Client, self).__init__(debug=debug, profile=profile)
382 370 else:
383 371 super(Client, self).__init__(debug=debug)
384 372 if context is None:
385 373 context = zmq.Context.instance()
386 374 self._context = context
387 375 self._stop_spinning = Event()
388 376
389 377 if 'url_or_file' in extra_args:
390 378 url_file = extra_args['url_or_file']
391 379 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
392 380
393 381 if url_file and util.is_url(url_file):
394 382 raise ValueError("single urls cannot be specified, url-files must be used.")
395 383
396 384 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
397 385
398 386 if self._cd is not None:
399 387 if url_file is None:
400 388 if not cluster_id:
401 389 client_json = 'ipcontroller-client.json'
402 390 else:
403 391 client_json = 'ipcontroller-%s-client.json' % cluster_id
404 392 url_file = pjoin(self._cd.security_dir, client_json)
405 393 if url_file is None:
406 394 raise ValueError(
407 395 "I can't find enough information to connect to a hub!"
408 396 " Please specify at least one of url_file or profile."
409 397 )
410 398
411 399 with open(url_file) as f:
412 400 cfg = json.load(f)
413 401
414 402 self._task_scheme = cfg['task_scheme']
415 403
416 404 # sync defaults from args, json:
417 405 if sshserver:
418 406 cfg['ssh'] = sshserver
419 407
420 408 location = cfg.setdefault('location', None)
421 409
422 410 proto,addr = cfg['interface'].split('://')
423 411 addr = util.disambiguate_ip_address(addr, location)
424 412 cfg['interface'] = "%s://%s" % (proto, addr)
425 413
426 414 # turn interface,port into full urls:
427 415 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
428 416 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
429 417
430 418 url = cfg['registration']
431 419
432 420 if location is not None and addr == localhost():
433 421 # location specified, and connection is expected to be local
434 422 if not is_local_ip(location) and not sshserver:
435 423 # load ssh from JSON *only* if the controller is not on
436 424 # this machine
437 425 sshserver=cfg['ssh']
438 426 if not is_local_ip(location) and not sshserver:
439 427 # warn if no ssh specified, but SSH is probably needed
440 428 # This is only a warning, because the most likely cause
441 429 # is a local Controller on a laptop whose IP is dynamic
442 430 warnings.warn("""
443 431 Controller appears to be listening on localhost, but not on this machine.
444 432 If this is true, you should specify Client(...,sshserver='you@%s')
445 433 or instruct your controller to listen on an external IP."""%location,
446 434 RuntimeWarning)
447 435 elif not sshserver:
448 436 # otherwise sync with cfg
449 437 sshserver = cfg['ssh']
450 438
451 439 self._config = cfg
452 440
453 441 self._ssh = bool(sshserver or sshkey or password)
454 442 if self._ssh and sshserver is None:
455 443 # default to ssh via localhost
456 444 sshserver = addr
457 445 if self._ssh and password is None:
458 446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
459 447 password=False
460 448 else:
461 449 password = getpass("SSH Password for %s: "%sshserver)
462 450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
463 451
464 452 # configure and construct the session
465 453 try:
466 454 extra_args['packer'] = cfg['pack']
467 455 extra_args['unpacker'] = cfg['unpack']
468 456 extra_args['key'] = cast_bytes(cfg['key'])
469 457 extra_args['signature_scheme'] = cfg['signature_scheme']
470 458 except KeyError as exc:
471 459 msg = '\n'.join([
472 460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
473 461 "If you are reusing connection files, remove them and start ipcontroller again."
474 462 ])
475 463 raise ValueError(msg.format(exc.message))
476 464
477 465 self.session = Session(**extra_args)
478 466
479 467 self._query_socket = self._context.socket(zmq.DEALER)
480 468
481 469 if self._ssh:
482 470 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
483 471 else:
484 472 self._query_socket.connect(cfg['registration'])
485 473
486 474 self.session.debug = self.debug
487 475
488 476 self._notification_handlers = {'registration_notification' : self._register_engine,
489 477 'unregistration_notification' : self._unregister_engine,
490 478 'shutdown_notification' : lambda msg: self.close(),
491 479 }
492 480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
493 481 'apply_reply' : self._handle_apply_reply}
494 482
495 483 try:
496 484 self._connect(sshserver, ssh_kwargs, timeout)
497 485 except:
498 486 self.close(linger=0)
499 487 raise
500 488
501 489 # last step: setup magics, if we are in IPython:
502 490
503 491 try:
504 492 ip = get_ipython()
505 493 except NameError:
506 494 return
507 495 else:
508 496 if 'px' not in ip.magics_manager.magics:
509 497 # in IPython but we are the first Client.
510 498 # activate a default view for parallel magics.
511 499 self.activate()
512 500
513 501 def __del__(self):
514 502 """cleanup sockets, but _not_ context."""
515 503 self.close()
516 504
517 505 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
518 506 if ipython_dir is None:
519 507 ipython_dir = get_ipython_dir()
520 508 if profile_dir is not None:
521 509 try:
522 510 self._cd = ProfileDir.find_profile_dir(profile_dir)
523 511 return
524 512 except ProfileDirError:
525 513 pass
526 514 elif profile is not None:
527 515 try:
528 516 self._cd = ProfileDir.find_profile_dir_by_name(
529 517 ipython_dir, profile)
530 518 return
531 519 except ProfileDirError:
532 520 pass
533 521 self._cd = None
534 522
535 523 def _update_engines(self, engines):
536 524 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
537 525 for k,v in iteritems(engines):
538 526 eid = int(k)
539 527 if eid not in self._engines:
540 528 self._ids.append(eid)
541 529 self._engines[eid] = v
542 530 self._ids = sorted(self._ids)
543 531 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
544 532 self._task_scheme == 'pure' and self._task_socket:
545 533 self._stop_scheduling_tasks()
546 534
547 535 def _stop_scheduling_tasks(self):
548 536 """Stop scheduling tasks because an engine has been unregistered
549 537 from a pure ZMQ scheduler.
550 538 """
551 539 self._task_socket.close()
552 540 self._task_socket = None
553 541 msg = "An engine has been unregistered, and we are using pure " +\
554 542 "ZMQ task scheduling. Task farming will be disabled."
555 543 if self.outstanding:
556 544 msg += " If you were running tasks when this happened, " +\
557 545 "some `outstanding` msg_ids may never resolve."
558 546 warnings.warn(msg, RuntimeWarning)
559 547
560 548 def _build_targets(self, targets):
561 549 """Turn valid target IDs or 'all' into two lists:
562 550 (int_ids, uuids).
563 551 """
564 552 if not self._ids:
565 553 # flush notification socket if no engines yet, just in case
566 554 if not self.ids:
567 555 raise error.NoEnginesRegistered("Can't build targets without any engines")
568 556
569 557 if targets is None:
570 558 targets = self._ids
571 559 elif isinstance(targets, string_types):
572 560 if targets.lower() == 'all':
573 561 targets = self._ids
574 562 else:
575 563 raise TypeError("%r not valid str target, must be 'all'"%(targets))
576 564 elif isinstance(targets, int):
577 565 if targets < 0:
578 566 targets = self.ids[targets]
579 567 if targets not in self._ids:
580 568 raise IndexError("No such engine: %i"%targets)
581 569 targets = [targets]
582 570
583 571 if isinstance(targets, slice):
584 572 indices = list(range(len(self._ids))[targets])
585 573 ids = self.ids
586 574 targets = [ ids[i] for i in indices ]
587 575
588 576 if not isinstance(targets, (tuple, list, xrange)):
589 577 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
590 578
591 579 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
592 580
593 581 def _connect(self, sshserver, ssh_kwargs, timeout):
594 582 """setup all our socket connections to the cluster. This is called from
595 583 __init__."""
596 584
597 585 # Maybe allow reconnecting?
598 586 if self._connected:
599 587 return
600 588 self._connected=True
601 589
602 590 def connect_socket(s, url):
603 591 if self._ssh:
604 592 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
605 593 else:
606 594 return s.connect(url)
607 595
608 596 self.session.send(self._query_socket, 'connection_request')
609 597 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
610 598 poller = zmq.Poller()
611 599 poller.register(self._query_socket, zmq.POLLIN)
612 600 # poll expects milliseconds, timeout is seconds
613 601 evts = poller.poll(timeout*1000)
614 602 if not evts:
615 603 raise error.TimeoutError("Hub connection request timed out")
616 604 idents,msg = self.session.recv(self._query_socket,mode=0)
617 605 if self.debug:
618 606 pprint(msg)
619 607 content = msg['content']
620 608 # self._config['registration'] = dict(content)
621 609 cfg = self._config
622 610 if content['status'] == 'ok':
623 611 self._mux_socket = self._context.socket(zmq.DEALER)
624 612 connect_socket(self._mux_socket, cfg['mux'])
625 613
626 614 self._task_socket = self._context.socket(zmq.DEALER)
627 615 connect_socket(self._task_socket, cfg['task'])
628 616
629 617 self._notification_socket = self._context.socket(zmq.SUB)
630 618 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
631 619 connect_socket(self._notification_socket, cfg['notification'])
632 620
633 621 self._control_socket = self._context.socket(zmq.DEALER)
634 622 connect_socket(self._control_socket, cfg['control'])
635 623
636 624 self._iopub_socket = self._context.socket(zmq.SUB)
637 625 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
638 626 connect_socket(self._iopub_socket, cfg['iopub'])
639 627
640 628 self._update_engines(dict(content['engines']))
641 629 else:
642 630 self._connected = False
643 631 raise Exception("Failed to connect!")
644 632
645 633 #--------------------------------------------------------------------------
646 634 # handlers and callbacks for incoming messages
647 635 #--------------------------------------------------------------------------
648 636
649 637 def _unwrap_exception(self, content):
650 638 """unwrap exception, and remap engine_id to int."""
651 639 e = error.unwrap_exception(content)
652 640 # print e.traceback
653 641 if e.engine_info:
654 642 e_uuid = e.engine_info['engine_uuid']
655 643 eid = self._engines[e_uuid]
656 644 e.engine_info['engine_id'] = eid
657 645 return e
658 646
659 647 def _extract_metadata(self, msg):
660 648 header = msg['header']
661 649 parent = msg['parent_header']
662 650 msg_meta = msg['metadata']
663 651 content = msg['content']
664 652 md = {'msg_id' : parent['msg_id'],
665 653 'received' : datetime.now(),
666 654 'engine_uuid' : msg_meta.get('engine', None),
667 655 'follow' : msg_meta.get('follow', []),
668 656 'after' : msg_meta.get('after', []),
669 657 'status' : content['status'],
670 658 }
671 659
672 660 if md['engine_uuid'] is not None:
673 661 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
674 662
675 663 if 'date' in parent:
676 664 md['submitted'] = parent['date']
677 665 if 'started' in msg_meta:
678 666 md['started'] = parse_date(msg_meta['started'])
679 667 if 'date' in header:
680 668 md['completed'] = header['date']
681 669 return md
682 670
683 671 def _register_engine(self, msg):
684 672 """Register a new engine, and update our connection info."""
685 673 content = msg['content']
686 674 eid = content['id']
687 675 d = {eid : content['uuid']}
688 676 self._update_engines(d)
689 677
690 678 def _unregister_engine(self, msg):
691 679 """Unregister an engine that has died."""
692 680 content = msg['content']
693 681 eid = int(content['id'])
694 682 if eid in self._ids:
695 683 self._ids.remove(eid)
696 684 uuid = self._engines.pop(eid)
697 685
698 686 self._handle_stranded_msgs(eid, uuid)
699 687
700 688 if self._task_socket and self._task_scheme == 'pure':
701 689 self._stop_scheduling_tasks()
702 690
703 691 def _handle_stranded_msgs(self, eid, uuid):
704 692 """Handle messages known to be on an engine when the engine unregisters.
705 693
706 694 It is possible that this will fire prematurely - that is, an engine will
707 695 go down after completing a result, and the client will be notified
708 696 of the unregistration and later receive the successful result.
709 697 """
710 698
711 699 outstanding = self._outstanding_dict[uuid]
712 700
713 701 for msg_id in list(outstanding):
714 702 if msg_id in self.results:
715 703 # we already
716 704 continue
717 705 try:
718 706 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
719 707 except:
720 708 content = error.wrap_exception()
721 709 # build a fake message:
722 710 msg = self.session.msg('apply_reply', content=content)
723 711 msg['parent_header']['msg_id'] = msg_id
724 712 msg['metadata']['engine'] = uuid
725 713 self._handle_apply_reply(msg)
726 714
727 715 def _handle_execute_reply(self, msg):
728 716 """Save the reply to an execute_request into our results.
729 717
730 718 execute messages are never actually used. apply is used instead.
731 719 """
732 720
733 721 parent = msg['parent_header']
734 722 msg_id = parent['msg_id']
735 723 if msg_id not in self.outstanding:
736 724 if msg_id in self.history:
737 725 print("got stale result: %s"%msg_id)
738 726 else:
739 727 print("got unknown result: %s"%msg_id)
740 728 else:
741 729 self.outstanding.remove(msg_id)
742 730
743 731 content = msg['content']
744 732 header = msg['header']
745 733
746 734 # construct metadata:
747 735 md = self.metadata[msg_id]
748 736 md.update(self._extract_metadata(msg))
749 737 # is this redundant?
750 738 self.metadata[msg_id] = md
751 739
752 740 e_outstanding = self._outstanding_dict[md['engine_uuid']]
753 741 if msg_id in e_outstanding:
754 742 e_outstanding.remove(msg_id)
755 743
756 744 # construct result:
757 745 if content['status'] == 'ok':
758 746 self.results[msg_id] = ExecuteReply(msg_id, content, md)
759 747 elif content['status'] == 'aborted':
760 748 self.results[msg_id] = error.TaskAborted(msg_id)
761 749 elif content['status'] == 'resubmitted':
762 750 # TODO: handle resubmission
763 751 pass
764 752 else:
765 753 self.results[msg_id] = self._unwrap_exception(content)
766 754
767 755 def _handle_apply_reply(self, msg):
768 756 """Save the reply to an apply_request into our results."""
769 757 parent = msg['parent_header']
770 758 msg_id = parent['msg_id']
771 759 if msg_id not in self.outstanding:
772 760 if msg_id in self.history:
773 761 print("got stale result: %s"%msg_id)
774 762 print(self.results[msg_id])
775 763 print(msg)
776 764 else:
777 765 print("got unknown result: %s"%msg_id)
778 766 else:
779 767 self.outstanding.remove(msg_id)
780 768 content = msg['content']
781 769 header = msg['header']
782 770
783 771 # construct metadata:
784 772 md = self.metadata[msg_id]
785 773 md.update(self._extract_metadata(msg))
786 774 # is this redundant?
787 775 self.metadata[msg_id] = md
788 776
789 777 e_outstanding = self._outstanding_dict[md['engine_uuid']]
790 778 if msg_id in e_outstanding:
791 779 e_outstanding.remove(msg_id)
792 780
793 781 # construct result:
794 782 if content['status'] == 'ok':
795 783 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
796 784 elif content['status'] == 'aborted':
797 785 self.results[msg_id] = error.TaskAborted(msg_id)
798 786 elif content['status'] == 'resubmitted':
799 787 # TODO: handle resubmission
800 788 pass
801 789 else:
802 790 self.results[msg_id] = self._unwrap_exception(content)
803 791
804 792 def _flush_notifications(self):
805 793 """Flush notifications of engine registrations waiting
806 794 in ZMQ queue."""
807 795 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 796 while msg is not None:
809 797 if self.debug:
810 798 pprint(msg)
811 799 msg_type = msg['header']['msg_type']
812 800 handler = self._notification_handlers.get(msg_type, None)
813 801 if handler is None:
814 802 raise Exception("Unhandled message type: %s" % msg_type)
815 803 else:
816 804 handler(msg)
817 805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
818 806
819 807 def _flush_results(self, sock):
820 808 """Flush task or queue results waiting in ZMQ queue."""
821 809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 810 while msg is not None:
823 811 if self.debug:
824 812 pprint(msg)
825 813 msg_type = msg['header']['msg_type']
826 814 handler = self._queue_handlers.get(msg_type, None)
827 815 if handler is None:
828 816 raise Exception("Unhandled message type: %s" % msg_type)
829 817 else:
830 818 handler(msg)
831 819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832 820
833 821 def _flush_control(self, sock):
834 822 """Flush replies from the control channel waiting
835 823 in the ZMQ queue.
836 824
837 825 Currently: ignore them."""
838 826 if self._ignored_control_replies <= 0:
839 827 return
840 828 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 829 while msg is not None:
842 830 self._ignored_control_replies -= 1
843 831 if self.debug:
844 832 pprint(msg)
845 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846 834
847 835 def _flush_ignored_control(self):
848 836 """flush ignored control replies"""
849 837 while self._ignored_control_replies > 0:
850 838 self.session.recv(self._control_socket)
851 839 self._ignored_control_replies -= 1
852 840
853 841 def _flush_ignored_hub_replies(self):
854 842 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
855 843 while msg is not None:
856 844 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
857 845
858 846 def _flush_iopub(self, sock):
859 847 """Flush replies from the iopub channel waiting
860 848 in the ZMQ queue.
861 849 """
862 850 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
863 851 while msg is not None:
864 852 if self.debug:
865 853 pprint(msg)
866 854 parent = msg['parent_header']
867 855 # ignore IOPub messages with no parent.
868 856 # Caused by print statements or warnings from before the first execution.
869 857 if not parent:
870 858 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 859 continue
872 860 msg_id = parent['msg_id']
873 861 content = msg['content']
874 862 header = msg['header']
875 863 msg_type = msg['header']['msg_type']
876 864
877 865 # init metadata:
878 866 md = self.metadata[msg_id]
879 867
880 868 if msg_type == 'stream':
881 869 name = content['name']
882 870 s = md[name] or ''
883 871 md[name] = s + content['data']
884 872 elif msg_type == 'pyerr':
885 873 md.update({'pyerr' : self._unwrap_exception(content)})
886 elif msg_type == 'pyin':
887 md.update({'pyin' : content['code']})
874 elif msg_type == 'execute_input':
875 md.update({'execute_input' : content['code']})
888 876 elif msg_type == 'display_data':
889 877 md['outputs'].append(content)
890 878 elif msg_type == 'pyout':
891 879 md['pyout'] = content
892 880 elif msg_type == 'data_message':
893 881 data, remainder = serialize.unserialize_object(msg['buffers'])
894 882 md['data'].update(data)
895 883 elif msg_type == 'status':
896 884 # idle message comes after all outputs
897 885 if content['execution_state'] == 'idle':
898 886 md['outputs_ready'] = True
899 887 else:
900 888 # unhandled msg_type (status, etc.)
901 889 pass
902 890
903 891 # reduntant?
904 892 self.metadata[msg_id] = md
905 893
906 894 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
907 895
908 896 #--------------------------------------------------------------------------
909 897 # len, getitem
910 898 #--------------------------------------------------------------------------
911 899
912 900 def __len__(self):
913 901 """len(client) returns # of engines."""
914 902 return len(self.ids)
915 903
916 904 def __getitem__(self, key):
917 905 """index access returns DirectView multiplexer objects
918 906
919 907 Must be int, slice, or list/tuple/xrange of ints"""
920 908 if not isinstance(key, (int, slice, tuple, list, xrange)):
921 909 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
922 910 else:
923 911 return self.direct_view(key)
924 912
925 913 def __iter__(self):
926 914 """Since we define getitem, Client is iterable
927 915
928 916 but unless we also define __iter__, it won't work correctly unless engine IDs
929 917 start at zero and are continuous.
930 918 """
931 919 for eid in self.ids:
932 920 yield self.direct_view(eid)
933 921
934 922 #--------------------------------------------------------------------------
935 923 # Begin public methods
936 924 #--------------------------------------------------------------------------
937 925
938 926 @property
939 927 def ids(self):
940 928 """Always up-to-date ids property."""
941 929 self._flush_notifications()
942 930 # always copy:
943 931 return list(self._ids)
944 932
945 933 def activate(self, targets='all', suffix=''):
946 934 """Create a DirectView and register it with IPython magics
947 935
948 936 Defines the magics `%px, %autopx, %pxresult, %%px`
949 937
950 938 Parameters
951 939 ----------
952 940
953 941 targets: int, list of ints, or 'all'
954 942 The engines on which the view's magics will run
955 943 suffix: str [default: '']
956 944 The suffix, if any, for the magics. This allows you to have
957 945 multiple views associated with parallel magics at the same time.
958 946
959 947 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
960 948 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
961 949 on engine 0.
962 950 """
963 951 view = self.direct_view(targets)
964 952 view.block = True
965 953 view.activate(suffix)
966 954 return view
967 955
968 956 def close(self, linger=None):
969 957 """Close my zmq Sockets
970 958
971 959 If `linger`, set the zmq LINGER socket option,
972 960 which allows discarding of messages.
973 961 """
974 962 if self._closed:
975 963 return
976 964 self.stop_spin_thread()
977 965 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
978 966 for name in snames:
979 967 socket = getattr(self, name)
980 968 if socket is not None and not socket.closed:
981 969 if linger is not None:
982 970 socket.close(linger=linger)
983 971 else:
984 972 socket.close()
985 973 self._closed = True
986 974
987 975 def _spin_every(self, interval=1):
988 976 """target func for use in spin_thread"""
989 977 while True:
990 978 if self._stop_spinning.is_set():
991 979 return
992 980 time.sleep(interval)
993 981 self.spin()
994 982
995 983 def spin_thread(self, interval=1):
996 984 """call Client.spin() in a background thread on some regular interval
997 985
998 986 This helps ensure that messages don't pile up too much in the zmq queue
999 987 while you are working on other things, or just leaving an idle terminal.
1000 988
1001 989 It also helps limit potential padding of the `received` timestamp
1002 990 on AsyncResult objects, used for timings.
1003 991
1004 992 Parameters
1005 993 ----------
1006 994
1007 995 interval : float, optional
1008 996 The interval on which to spin the client in the background thread
1009 997 (simply passed to time.sleep).
1010 998
1011 999 Notes
1012 1000 -----
1013 1001
1014 1002 For precision timing, you may want to use this method to put a bound
1015 1003 on the jitter (in seconds) in `received` timestamps used
1016 1004 in AsyncResult.wall_time.
1017 1005
1018 1006 """
1019 1007 if self._spin_thread is not None:
1020 1008 self.stop_spin_thread()
1021 1009 self._stop_spinning.clear()
1022 1010 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1023 1011 self._spin_thread.daemon = True
1024 1012 self._spin_thread.start()
1025 1013
1026 1014 def stop_spin_thread(self):
1027 1015 """stop background spin_thread, if any"""
1028 1016 if self._spin_thread is not None:
1029 1017 self._stop_spinning.set()
1030 1018 self._spin_thread.join()
1031 1019 self._spin_thread = None
1032 1020
1033 1021 def spin(self):
1034 1022 """Flush any registration notifications and execution results
1035 1023 waiting in the ZMQ queue.
1036 1024 """
1037 1025 if self._notification_socket:
1038 1026 self._flush_notifications()
1039 1027 if self._iopub_socket:
1040 1028 self._flush_iopub(self._iopub_socket)
1041 1029 if self._mux_socket:
1042 1030 self._flush_results(self._mux_socket)
1043 1031 if self._task_socket:
1044 1032 self._flush_results(self._task_socket)
1045 1033 if self._control_socket:
1046 1034 self._flush_control(self._control_socket)
1047 1035 if self._query_socket:
1048 1036 self._flush_ignored_hub_replies()
1049 1037
1050 1038 def wait(self, jobs=None, timeout=-1):
1051 1039 """waits on one or more `jobs`, for up to `timeout` seconds.
1052 1040
1053 1041 Parameters
1054 1042 ----------
1055 1043
1056 1044 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1057 1045 ints are indices to self.history
1058 1046 strs are msg_ids
1059 1047 default: wait on all outstanding messages
1060 1048 timeout : float
1061 1049 a time in seconds, after which to give up.
1062 1050 default is -1, which means no timeout
1063 1051
1064 1052 Returns
1065 1053 -------
1066 1054
1067 1055 True : when all msg_ids are done
1068 1056 False : timeout reached, some msg_ids still outstanding
1069 1057 """
1070 1058 tic = time.time()
1071 1059 if jobs is None:
1072 1060 theids = self.outstanding
1073 1061 else:
1074 1062 if isinstance(jobs, string_types + (int, AsyncResult)):
1075 1063 jobs = [jobs]
1076 1064 theids = set()
1077 1065 for job in jobs:
1078 1066 if isinstance(job, int):
1079 1067 # index access
1080 1068 job = self.history[job]
1081 1069 elif isinstance(job, AsyncResult):
1082 1070 theids.update(job.msg_ids)
1083 1071 continue
1084 1072 theids.add(job)
1085 1073 if not theids.intersection(self.outstanding):
1086 1074 return True
1087 1075 self.spin()
1088 1076 while theids.intersection(self.outstanding):
1089 1077 if timeout >= 0 and ( time.time()-tic ) > timeout:
1090 1078 break
1091 1079 time.sleep(1e-3)
1092 1080 self.spin()
1093 1081 return len(theids.intersection(self.outstanding)) == 0
1094 1082
1095 1083 #--------------------------------------------------------------------------
1096 1084 # Control methods
1097 1085 #--------------------------------------------------------------------------
1098 1086
1099 1087 @spin_first
1100 1088 def clear(self, targets=None, block=None):
1101 1089 """Clear the namespace in target(s)."""
1102 1090 block = self.block if block is None else block
1103 1091 targets = self._build_targets(targets)[0]
1104 1092 for t in targets:
1105 1093 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1106 1094 error = False
1107 1095 if block:
1108 1096 self._flush_ignored_control()
1109 1097 for i in range(len(targets)):
1110 1098 idents,msg = self.session.recv(self._control_socket,0)
1111 1099 if self.debug:
1112 1100 pprint(msg)
1113 1101 if msg['content']['status'] != 'ok':
1114 1102 error = self._unwrap_exception(msg['content'])
1115 1103 else:
1116 1104 self._ignored_control_replies += len(targets)
1117 1105 if error:
1118 1106 raise error
1119 1107
1120 1108
1121 1109 @spin_first
1122 1110 def abort(self, jobs=None, targets=None, block=None):
1123 1111 """Abort specific jobs from the execution queues of target(s).
1124 1112
1125 1113 This is a mechanism to prevent jobs that have already been submitted
1126 1114 from executing.
1127 1115
1128 1116 Parameters
1129 1117 ----------
1130 1118
1131 1119 jobs : msg_id, list of msg_ids, or AsyncResult
1132 1120 The jobs to be aborted
1133 1121
1134 1122 If unspecified/None: abort all outstanding jobs.
1135 1123
1136 1124 """
1137 1125 block = self.block if block is None else block
1138 1126 jobs = jobs if jobs is not None else list(self.outstanding)
1139 1127 targets = self._build_targets(targets)[0]
1140 1128
1141 1129 msg_ids = []
1142 1130 if isinstance(jobs, string_types + (AsyncResult,)):
1143 1131 jobs = [jobs]
1144 1132 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1145 1133 if bad_ids:
1146 1134 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1147 1135 for j in jobs:
1148 1136 if isinstance(j, AsyncResult):
1149 1137 msg_ids.extend(j.msg_ids)
1150 1138 else:
1151 1139 msg_ids.append(j)
1152 1140 content = dict(msg_ids=msg_ids)
1153 1141 for t in targets:
1154 1142 self.session.send(self._control_socket, 'abort_request',
1155 1143 content=content, ident=t)
1156 1144 error = False
1157 1145 if block:
1158 1146 self._flush_ignored_control()
1159 1147 for i in range(len(targets)):
1160 1148 idents,msg = self.session.recv(self._control_socket,0)
1161 1149 if self.debug:
1162 1150 pprint(msg)
1163 1151 if msg['content']['status'] != 'ok':
1164 1152 error = self._unwrap_exception(msg['content'])
1165 1153 else:
1166 1154 self._ignored_control_replies += len(targets)
1167 1155 if error:
1168 1156 raise error
1169 1157
1170 1158 @spin_first
1171 1159 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1172 1160 """Terminates one or more engine processes, optionally including the hub.
1173 1161
1174 1162 Parameters
1175 1163 ----------
1176 1164
1177 1165 targets: list of ints or 'all' [default: all]
1178 1166 Which engines to shutdown.
1179 1167 hub: bool [default: False]
1180 1168 Whether to include the Hub. hub=True implies targets='all'.
1181 1169 block: bool [default: self.block]
1182 1170 Whether to wait for clean shutdown replies or not.
1183 1171 restart: bool [default: False]
1184 1172 NOT IMPLEMENTED
1185 1173 whether to restart engines after shutting them down.
1186 1174 """
1187 1175 from IPython.parallel.error import NoEnginesRegistered
1188 1176 if restart:
1189 1177 raise NotImplementedError("Engine restart is not yet implemented")
1190 1178
1191 1179 block = self.block if block is None else block
1192 1180 if hub:
1193 1181 targets = 'all'
1194 1182 try:
1195 1183 targets = self._build_targets(targets)[0]
1196 1184 except NoEnginesRegistered:
1197 1185 targets = []
1198 1186 for t in targets:
1199 1187 self.session.send(self._control_socket, 'shutdown_request',
1200 1188 content={'restart':restart},ident=t)
1201 1189 error = False
1202 1190 if block or hub:
1203 1191 self._flush_ignored_control()
1204 1192 for i in range(len(targets)):
1205 1193 idents,msg = self.session.recv(self._control_socket, 0)
1206 1194 if self.debug:
1207 1195 pprint(msg)
1208 1196 if msg['content']['status'] != 'ok':
1209 1197 error = self._unwrap_exception(msg['content'])
1210 1198 else:
1211 1199 self._ignored_control_replies += len(targets)
1212 1200
1213 1201 if hub:
1214 1202 time.sleep(0.25)
1215 1203 self.session.send(self._query_socket, 'shutdown_request')
1216 1204 idents,msg = self.session.recv(self._query_socket, 0)
1217 1205 if self.debug:
1218 1206 pprint(msg)
1219 1207 if msg['content']['status'] != 'ok':
1220 1208 error = self._unwrap_exception(msg['content'])
1221 1209
1222 1210 if error:
1223 1211 raise error
1224 1212
1225 1213 #--------------------------------------------------------------------------
1226 1214 # Execution related methods
1227 1215 #--------------------------------------------------------------------------
1228 1216
1229 1217 def _maybe_raise(self, result):
1230 1218 """wrapper for maybe raising an exception if apply failed."""
1231 1219 if isinstance(result, error.RemoteError):
1232 1220 raise result
1233 1221
1234 1222 return result
1235 1223
1236 1224 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1237 1225 ident=None):
1238 1226 """construct and send an apply message via a socket.
1239 1227
1240 1228 This is the principal method with which all engine execution is performed by views.
1241 1229 """
1242 1230
1243 1231 if self._closed:
1244 1232 raise RuntimeError("Client cannot be used after its sockets have been closed")
1245 1233
1246 1234 # defaults:
1247 1235 args = args if args is not None else []
1248 1236 kwargs = kwargs if kwargs is not None else {}
1249 1237 metadata = metadata if metadata is not None else {}
1250 1238
1251 1239 # validate arguments
1252 1240 if not callable(f) and not isinstance(f, Reference):
1253 1241 raise TypeError("f must be callable, not %s"%type(f))
1254 1242 if not isinstance(args, (tuple, list)):
1255 1243 raise TypeError("args must be tuple or list, not %s"%type(args))
1256 1244 if not isinstance(kwargs, dict):
1257 1245 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1258 1246 if not isinstance(metadata, dict):
1259 1247 raise TypeError("metadata must be dict, not %s"%type(metadata))
1260 1248
1261 1249 bufs = serialize.pack_apply_message(f, args, kwargs,
1262 1250 buffer_threshold=self.session.buffer_threshold,
1263 1251 item_threshold=self.session.item_threshold,
1264 1252 )
1265 1253
1266 1254 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1267 1255 metadata=metadata, track=track)
1268 1256
1269 1257 msg_id = msg['header']['msg_id']
1270 1258 self.outstanding.add(msg_id)
1271 1259 if ident:
1272 1260 # possibly routed to a specific engine
1273 1261 if isinstance(ident, list):
1274 1262 ident = ident[-1]
1275 1263 if ident in self._engines.values():
1276 1264 # save for later, in case of engine death
1277 1265 self._outstanding_dict[ident].add(msg_id)
1278 1266 self.history.append(msg_id)
1279 1267 self.metadata[msg_id]['submitted'] = datetime.now()
1280 1268
1281 1269 return msg
1282 1270
1283 1271 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1284 1272 """construct and send an execute request via a socket.
1285 1273
1286 1274 """
1287 1275
1288 1276 if self._closed:
1289 1277 raise RuntimeError("Client cannot be used after its sockets have been closed")
1290 1278
1291 1279 # defaults:
1292 1280 metadata = metadata if metadata is not None else {}
1293 1281
1294 1282 # validate arguments
1295 1283 if not isinstance(code, string_types):
1296 1284 raise TypeError("code must be text, not %s" % type(code))
1297 1285 if not isinstance(metadata, dict):
1298 1286 raise TypeError("metadata must be dict, not %s" % type(metadata))
1299 1287
1300 1288 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1301 1289
1302 1290
1303 1291 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1304 1292 metadata=metadata)
1305 1293
1306 1294 msg_id = msg['header']['msg_id']
1307 1295 self.outstanding.add(msg_id)
1308 1296 if ident:
1309 1297 # possibly routed to a specific engine
1310 1298 if isinstance(ident, list):
1311 1299 ident = ident[-1]
1312 1300 if ident in self._engines.values():
1313 1301 # save for later, in case of engine death
1314 1302 self._outstanding_dict[ident].add(msg_id)
1315 1303 self.history.append(msg_id)
1316 1304 self.metadata[msg_id]['submitted'] = datetime.now()
1317 1305
1318 1306 return msg
1319 1307
1320 1308 #--------------------------------------------------------------------------
1321 1309 # construct a View object
1322 1310 #--------------------------------------------------------------------------
1323 1311
1324 1312 def load_balanced_view(self, targets=None):
1325 1313 """construct a DirectView object.
1326 1314
1327 1315 If no arguments are specified, create a LoadBalancedView
1328 1316 using all engines.
1329 1317
1330 1318 Parameters
1331 1319 ----------
1332 1320
1333 1321 targets: list,slice,int,etc. [default: use all engines]
1334 1322 The subset of engines across which to load-balance
1335 1323 """
1336 1324 if targets == 'all':
1337 1325 targets = None
1338 1326 if targets is not None:
1339 1327 targets = self._build_targets(targets)[1]
1340 1328 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1341 1329
1342 1330 def direct_view(self, targets='all'):
1343 1331 """construct a DirectView object.
1344 1332
1345 1333 If no targets are specified, create a DirectView using all engines.
1346 1334
1347 1335 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1348 1336 evaluate the target engines at each execution, whereas rc[:] will connect to
1349 1337 all *current* engines, and that list will not change.
1350 1338
1351 1339 That is, 'all' will always use all engines, whereas rc[:] will not use
1352 1340 engines added after the DirectView is constructed.
1353 1341
1354 1342 Parameters
1355 1343 ----------
1356 1344
1357 1345 targets: list,slice,int,etc. [default: use all engines]
1358 1346 The engines to use for the View
1359 1347 """
1360 1348 single = isinstance(targets, int)
1361 1349 # allow 'all' to be lazily evaluated at each execution
1362 1350 if targets != 'all':
1363 1351 targets = self._build_targets(targets)[1]
1364 1352 if single:
1365 1353 targets = targets[0]
1366 1354 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1367 1355
1368 1356 #--------------------------------------------------------------------------
1369 1357 # Query methods
1370 1358 #--------------------------------------------------------------------------
1371 1359
1372 1360 @spin_first
1373 1361 def get_result(self, indices_or_msg_ids=None, block=None):
1374 1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1375 1363
1376 1364 If the client already has the results, no request to the Hub will be made.
1377 1365
1378 1366 This is a convenient way to construct AsyncResult objects, which are wrappers
1379 1367 that include metadata about execution, and allow for awaiting results that
1380 1368 were not submitted by this Client.
1381 1369
1382 1370 It can also be a convenient way to retrieve the metadata associated with
1383 1371 blocking execution, since it always retrieves
1384 1372
1385 1373 Examples
1386 1374 --------
1387 1375 ::
1388 1376
1389 1377 In [10]: r = client.apply()
1390 1378
1391 1379 Parameters
1392 1380 ----------
1393 1381
1394 1382 indices_or_msg_ids : integer history index, str msg_id, or list of either
1395 1383 The indices or msg_ids of indices to be retrieved
1396 1384
1397 1385 block : bool
1398 1386 Whether to wait for the result to be done
1399 1387
1400 1388 Returns
1401 1389 -------
1402 1390
1403 1391 AsyncResult
1404 1392 A single AsyncResult object will always be returned.
1405 1393
1406 1394 AsyncHubResult
1407 1395 A subclass of AsyncResult that retrieves results from the Hub
1408 1396
1409 1397 """
1410 1398 block = self.block if block is None else block
1411 1399 if indices_or_msg_ids is None:
1412 1400 indices_or_msg_ids = -1
1413 1401
1414 1402 single_result = False
1415 1403 if not isinstance(indices_or_msg_ids, (list,tuple)):
1416 1404 indices_or_msg_ids = [indices_or_msg_ids]
1417 1405 single_result = True
1418 1406
1419 1407 theids = []
1420 1408 for id in indices_or_msg_ids:
1421 1409 if isinstance(id, int):
1422 1410 id = self.history[id]
1423 1411 if not isinstance(id, string_types):
1424 1412 raise TypeError("indices must be str or int, not %r"%id)
1425 1413 theids.append(id)
1426 1414
1427 1415 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1428 1416 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1429 1417
1430 1418 # given single msg_id initially, get_result shot get the result itself,
1431 1419 # not a length-one list
1432 1420 if single_result:
1433 1421 theids = theids[0]
1434 1422
1435 1423 if remote_ids:
1436 1424 ar = AsyncHubResult(self, msg_ids=theids)
1437 1425 else:
1438 1426 ar = AsyncResult(self, msg_ids=theids)
1439 1427
1440 1428 if block:
1441 1429 ar.wait()
1442 1430
1443 1431 return ar
1444 1432
1445 1433 @spin_first
1446 1434 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1447 1435 """Resubmit one or more tasks.
1448 1436
1449 1437 in-flight tasks may not be resubmitted.
1450 1438
1451 1439 Parameters
1452 1440 ----------
1453 1441
1454 1442 indices_or_msg_ids : integer history index, str msg_id, or list of either
1455 1443 The indices or msg_ids of indices to be retrieved
1456 1444
1457 1445 block : bool
1458 1446 Whether to wait for the result to be done
1459 1447
1460 1448 Returns
1461 1449 -------
1462 1450
1463 1451 AsyncHubResult
1464 1452 A subclass of AsyncResult that retrieves results from the Hub
1465 1453
1466 1454 """
1467 1455 block = self.block if block is None else block
1468 1456 if indices_or_msg_ids is None:
1469 1457 indices_or_msg_ids = -1
1470 1458
1471 1459 if not isinstance(indices_or_msg_ids, (list,tuple)):
1472 1460 indices_or_msg_ids = [indices_or_msg_ids]
1473 1461
1474 1462 theids = []
1475 1463 for id in indices_or_msg_ids:
1476 1464 if isinstance(id, int):
1477 1465 id = self.history[id]
1478 1466 if not isinstance(id, string_types):
1479 1467 raise TypeError("indices must be str or int, not %r"%id)
1480 1468 theids.append(id)
1481 1469
1482 1470 content = dict(msg_ids = theids)
1483 1471
1484 1472 self.session.send(self._query_socket, 'resubmit_request', content)
1485 1473
1486 1474 zmq.select([self._query_socket], [], [])
1487 1475 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1488 1476 if self.debug:
1489 1477 pprint(msg)
1490 1478 content = msg['content']
1491 1479 if content['status'] != 'ok':
1492 1480 raise self._unwrap_exception(content)
1493 1481 mapping = content['resubmitted']
1494 1482 new_ids = [ mapping[msg_id] for msg_id in theids ]
1495 1483
1496 1484 ar = AsyncHubResult(self, msg_ids=new_ids)
1497 1485
1498 1486 if block:
1499 1487 ar.wait()
1500 1488
1501 1489 return ar
1502 1490
1503 1491 @spin_first
1504 1492 def result_status(self, msg_ids, status_only=True):
1505 1493 """Check on the status of the result(s) of the apply request with `msg_ids`.
1506 1494
1507 1495 If status_only is False, then the actual results will be retrieved, else
1508 1496 only the status of the results will be checked.
1509 1497
1510 1498 Parameters
1511 1499 ----------
1512 1500
1513 1501 msg_ids : list of msg_ids
1514 1502 if int:
1515 1503 Passed as index to self.history for convenience.
1516 1504 status_only : bool (default: True)
1517 1505 if False:
1518 1506 Retrieve the actual results of completed tasks.
1519 1507
1520 1508 Returns
1521 1509 -------
1522 1510
1523 1511 results : dict
1524 1512 There will always be the keys 'pending' and 'completed', which will
1525 1513 be lists of msg_ids that are incomplete or complete. If `status_only`
1526 1514 is False, then completed results will be keyed by their `msg_id`.
1527 1515 """
1528 1516 if not isinstance(msg_ids, (list,tuple)):
1529 1517 msg_ids = [msg_ids]
1530 1518
1531 1519 theids = []
1532 1520 for msg_id in msg_ids:
1533 1521 if isinstance(msg_id, int):
1534 1522 msg_id = self.history[msg_id]
1535 1523 if not isinstance(msg_id, string_types):
1536 1524 raise TypeError("msg_ids must be str, not %r"%msg_id)
1537 1525 theids.append(msg_id)
1538 1526
1539 1527 completed = []
1540 1528 local_results = {}
1541 1529
1542 1530 # comment this block out to temporarily disable local shortcut:
1543 1531 for msg_id in theids:
1544 1532 if msg_id in self.results:
1545 1533 completed.append(msg_id)
1546 1534 local_results[msg_id] = self.results[msg_id]
1547 1535 theids.remove(msg_id)
1548 1536
1549 1537 if theids: # some not locally cached
1550 1538 content = dict(msg_ids=theids, status_only=status_only)
1551 1539 msg = self.session.send(self._query_socket, "result_request", content=content)
1552 1540 zmq.select([self._query_socket], [], [])
1553 1541 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1554 1542 if self.debug:
1555 1543 pprint(msg)
1556 1544 content = msg['content']
1557 1545 if content['status'] != 'ok':
1558 1546 raise self._unwrap_exception(content)
1559 1547 buffers = msg['buffers']
1560 1548 else:
1561 1549 content = dict(completed=[],pending=[])
1562 1550
1563 1551 content['completed'].extend(completed)
1564 1552
1565 1553 if status_only:
1566 1554 return content
1567 1555
1568 1556 failures = []
1569 1557 # load cached results into result:
1570 1558 content.update(local_results)
1571 1559
1572 1560 # update cache with results:
1573 1561 for msg_id in sorted(theids):
1574 1562 if msg_id in content['completed']:
1575 1563 rec = content[msg_id]
1576 1564 parent = extract_dates(rec['header'])
1577 1565 header = extract_dates(rec['result_header'])
1578 1566 rcontent = rec['result_content']
1579 1567 iodict = rec['io']
1580 1568 if isinstance(rcontent, str):
1581 1569 rcontent = self.session.unpack(rcontent)
1582 1570
1583 1571 md = self.metadata[msg_id]
1584 1572 md_msg = dict(
1585 1573 content=rcontent,
1586 1574 parent_header=parent,
1587 1575 header=header,
1588 1576 metadata=rec['result_metadata'],
1589 1577 )
1590 1578 md.update(self._extract_metadata(md_msg))
1591 1579 if rec.get('received'):
1592 1580 md['received'] = parse_date(rec['received'])
1593 1581 md.update(iodict)
1594 1582
1595 1583 if rcontent['status'] == 'ok':
1596 1584 if header['msg_type'] == 'apply_reply':
1597 1585 res,buffers = serialize.unserialize_object(buffers)
1598 1586 elif header['msg_type'] == 'execute_reply':
1599 1587 res = ExecuteReply(msg_id, rcontent, md)
1600 1588 else:
1601 1589 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1602 1590 else:
1603 1591 res = self._unwrap_exception(rcontent)
1604 1592 failures.append(res)
1605 1593
1606 1594 self.results[msg_id] = res
1607 1595 content[msg_id] = res
1608 1596
1609 1597 if len(theids) == 1 and failures:
1610 1598 raise failures[0]
1611 1599
1612 1600 error.collect_exceptions(failures, "result_status")
1613 1601 return content
1614 1602
1615 1603 @spin_first
1616 1604 def queue_status(self, targets='all', verbose=False):
1617 1605 """Fetch the status of engine queues.
1618 1606
1619 1607 Parameters
1620 1608 ----------
1621 1609
1622 1610 targets : int/str/list of ints/strs
1623 1611 the engines whose states are to be queried.
1624 1612 default : all
1625 1613 verbose : bool
1626 1614 Whether to return lengths only, or lists of ids for each element
1627 1615 """
1628 1616 if targets == 'all':
1629 1617 # allow 'all' to be evaluated on the engine
1630 1618 engine_ids = None
1631 1619 else:
1632 1620 engine_ids = self._build_targets(targets)[1]
1633 1621 content = dict(targets=engine_ids, verbose=verbose)
1634 1622 self.session.send(self._query_socket, "queue_request", content=content)
1635 1623 idents,msg = self.session.recv(self._query_socket, 0)
1636 1624 if self.debug:
1637 1625 pprint(msg)
1638 1626 content = msg['content']
1639 1627 status = content.pop('status')
1640 1628 if status != 'ok':
1641 1629 raise self._unwrap_exception(content)
1642 1630 content = rekey(content)
1643 1631 if isinstance(targets, int):
1644 1632 return content[targets]
1645 1633 else:
1646 1634 return content
1647 1635
1648 1636 def _build_msgids_from_target(self, targets=None):
1649 1637 """Build a list of msg_ids from the list of engine targets"""
1650 1638 if not targets: # needed as _build_targets otherwise uses all engines
1651 1639 return []
1652 1640 target_ids = self._build_targets(targets)[0]
1653 1641 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1654 1642
1655 1643 def _build_msgids_from_jobs(self, jobs=None):
1656 1644 """Build a list of msg_ids from "jobs" """
1657 1645 if not jobs:
1658 1646 return []
1659 1647 msg_ids = []
1660 1648 if isinstance(jobs, string_types + (AsyncResult,)):
1661 1649 jobs = [jobs]
1662 1650 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1663 1651 if bad_ids:
1664 1652 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1665 1653 for j in jobs:
1666 1654 if isinstance(j, AsyncResult):
1667 1655 msg_ids.extend(j.msg_ids)
1668 1656 else:
1669 1657 msg_ids.append(j)
1670 1658 return msg_ids
1671 1659
1672 1660 def purge_local_results(self, jobs=[], targets=[]):
1673 1661 """Clears the client caches of results and their metadata.
1674 1662
1675 1663 Individual results can be purged by msg_id, or the entire
1676 1664 history of specific targets can be purged.
1677 1665
1678 1666 Use `purge_local_results('all')` to scrub everything from the Clients's
1679 1667 results and metadata caches.
1680 1668
1681 1669 After this call all `AsyncResults` are invalid and should be discarded.
1682 1670
1683 1671 If you must "reget" the results, you can still do so by using
1684 1672 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1685 1673 redownload the results from the hub if they are still available
1686 1674 (i.e `client.purge_hub_results(...)` has not been called.
1687 1675
1688 1676 Parameters
1689 1677 ----------
1690 1678
1691 1679 jobs : str or list of str or AsyncResult objects
1692 1680 the msg_ids whose results should be purged.
1693 1681 targets : int/list of ints
1694 1682 The engines, by integer ID, whose entire result histories are to be purged.
1695 1683
1696 1684 Raises
1697 1685 ------
1698 1686
1699 1687 RuntimeError : if any of the tasks to be purged are still outstanding.
1700 1688
1701 1689 """
1702 1690 if not targets and not jobs:
1703 1691 raise ValueError("Must specify at least one of `targets` and `jobs`")
1704 1692
1705 1693 if jobs == 'all':
1706 1694 if self.outstanding:
1707 1695 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1708 1696 self.results.clear()
1709 1697 self.metadata.clear()
1710 1698 else:
1711 1699 msg_ids = set()
1712 1700 msg_ids.update(self._build_msgids_from_target(targets))
1713 1701 msg_ids.update(self._build_msgids_from_jobs(jobs))
1714 1702 still_outstanding = self.outstanding.intersection(msg_ids)
1715 1703 if still_outstanding:
1716 1704 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1717 1705 for mid in msg_ids:
1718 1706 self.results.pop(mid)
1719 1707 self.metadata.pop(mid)
1720 1708
1721 1709
1722 1710 @spin_first
1723 1711 def purge_hub_results(self, jobs=[], targets=[]):
1724 1712 """Tell the Hub to forget results.
1725 1713
1726 1714 Individual results can be purged by msg_id, or the entire
1727 1715 history of specific targets can be purged.
1728 1716
1729 1717 Use `purge_results('all')` to scrub everything from the Hub's db.
1730 1718
1731 1719 Parameters
1732 1720 ----------
1733 1721
1734 1722 jobs : str or list of str or AsyncResult objects
1735 1723 the msg_ids whose results should be forgotten.
1736 1724 targets : int/str/list of ints/strs
1737 1725 The targets, by int_id, whose entire history is to be purged.
1738 1726
1739 1727 default : None
1740 1728 """
1741 1729 if not targets and not jobs:
1742 1730 raise ValueError("Must specify at least one of `targets` and `jobs`")
1743 1731 if targets:
1744 1732 targets = self._build_targets(targets)[1]
1745 1733
1746 1734 # construct msg_ids from jobs
1747 1735 if jobs == 'all':
1748 1736 msg_ids = jobs
1749 1737 else:
1750 1738 msg_ids = self._build_msgids_from_jobs(jobs)
1751 1739
1752 1740 content = dict(engine_ids=targets, msg_ids=msg_ids)
1753 1741 self.session.send(self._query_socket, "purge_request", content=content)
1754 1742 idents, msg = self.session.recv(self._query_socket, 0)
1755 1743 if self.debug:
1756 1744 pprint(msg)
1757 1745 content = msg['content']
1758 1746 if content['status'] != 'ok':
1759 1747 raise self._unwrap_exception(content)
1760 1748
1761 1749 def purge_results(self, jobs=[], targets=[]):
1762 1750 """Clears the cached results from both the hub and the local client
1763 1751
1764 1752 Individual results can be purged by msg_id, or the entire
1765 1753 history of specific targets can be purged.
1766 1754
1767 1755 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1768 1756 the Client's db.
1769 1757
1770 1758 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1771 1759 the same arguments.
1772 1760
1773 1761 Parameters
1774 1762 ----------
1775 1763
1776 1764 jobs : str or list of str or AsyncResult objects
1777 1765 the msg_ids whose results should be forgotten.
1778 1766 targets : int/str/list of ints/strs
1779 1767 The targets, by int_id, whose entire history is to be purged.
1780 1768
1781 1769 default : None
1782 1770 """
1783 1771 self.purge_local_results(jobs=jobs, targets=targets)
1784 1772 self.purge_hub_results(jobs=jobs, targets=targets)
1785 1773
1786 1774 def purge_everything(self):
1787 1775 """Clears all content from previous Tasks from both the hub and the local client
1788 1776
1789 1777 In addition to calling `purge_results("all")` it also deletes the history and
1790 1778 other bookkeeping lists.
1791 1779 """
1792 1780 self.purge_results("all")
1793 1781 self.history = []
1794 1782 self.session.digest_history.clear()
1795 1783
1796 1784 @spin_first
1797 1785 def hub_history(self):
1798 1786 """Get the Hub's history
1799 1787
1800 1788 Just like the Client, the Hub has a history, which is a list of msg_ids.
1801 1789 This will contain the history of all clients, and, depending on configuration,
1802 1790 may contain history across multiple cluster sessions.
1803 1791
1804 1792 Any msg_id returned here is a valid argument to `get_result`.
1805 1793
1806 1794 Returns
1807 1795 -------
1808 1796
1809 1797 msg_ids : list of strs
1810 1798 list of all msg_ids, ordered by task submission time.
1811 1799 """
1812 1800
1813 1801 self.session.send(self._query_socket, "history_request", content={})
1814 1802 idents, msg = self.session.recv(self._query_socket, 0)
1815 1803
1816 1804 if self.debug:
1817 1805 pprint(msg)
1818 1806 content = msg['content']
1819 1807 if content['status'] != 'ok':
1820 1808 raise self._unwrap_exception(content)
1821 1809 else:
1822 1810 return content['history']
1823 1811
1824 1812 @spin_first
1825 1813 def db_query(self, query, keys=None):
1826 1814 """Query the Hub's TaskRecord database
1827 1815
1828 1816 This will return a list of task record dicts that match `query`
1829 1817
1830 1818 Parameters
1831 1819 ----------
1832 1820
1833 1821 query : mongodb query dict
1834 1822 The search dict. See mongodb query docs for details.
1835 1823 keys : list of strs [optional]
1836 1824 The subset of keys to be returned. The default is to fetch everything but buffers.
1837 1825 'msg_id' will *always* be included.
1838 1826 """
1839 1827 if isinstance(keys, string_types):
1840 1828 keys = [keys]
1841 1829 content = dict(query=query, keys=keys)
1842 1830 self.session.send(self._query_socket, "db_request", content=content)
1843 1831 idents, msg = self.session.recv(self._query_socket, 0)
1844 1832 if self.debug:
1845 1833 pprint(msg)
1846 1834 content = msg['content']
1847 1835 if content['status'] != 'ok':
1848 1836 raise self._unwrap_exception(content)
1849 1837
1850 1838 records = content['records']
1851 1839
1852 1840 buffer_lens = content['buffer_lens']
1853 1841 result_buffer_lens = content['result_buffer_lens']
1854 1842 buffers = msg['buffers']
1855 1843 has_bufs = buffer_lens is not None
1856 1844 has_rbufs = result_buffer_lens is not None
1857 1845 for i,rec in enumerate(records):
1858 1846 # unpack datetime objects
1859 1847 for hkey in ('header', 'result_header'):
1860 1848 if hkey in rec:
1861 1849 rec[hkey] = extract_dates(rec[hkey])
1862 1850 for dtkey in ('submitted', 'started', 'completed', 'received'):
1863 1851 if dtkey in rec:
1864 1852 rec[dtkey] = parse_date(rec[dtkey])
1865 1853 # relink buffers
1866 1854 if has_bufs:
1867 1855 blen = buffer_lens[i]
1868 1856 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1869 1857 if has_rbufs:
1870 1858 blen = result_buffer_lens[i]
1871 1859 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1872 1860
1873 1861 return records
1874 1862
1875 1863 __all__ = [ 'Client' ]
@@ -1,1449 +1,1440 b''
1 1 """The IPython Controller Hub with 0MQ
2
2 3 This is the master object that handles connections from engines and clients,
3 4 and monitors traffic through the various queues.
4
5 Authors:
6
7 * Min RK
8 5 """
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
15 6
16 #-----------------------------------------------------------------------------
17 # Imports
18 #-----------------------------------------------------------------------------
7 # Copyright (c) IPython Development Team.
8 # Distributed under the terms of the Modified BSD License.
9
19 10 from __future__ import print_function
20 11
21 12 import json
22 13 import os
23 14 import sys
24 15 import time
25 16 from datetime import datetime
26 17
27 18 import zmq
28 19 from zmq.eventloop import ioloop
29 20 from zmq.eventloop.zmqstream import ZMQStream
30 21
31 22 # internal:
32 23 from IPython.utils.importstring import import_item
33 24 from IPython.utils.jsonutil import extract_dates
34 25 from IPython.utils.localinterfaces import localhost
35 26 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
36 27 from IPython.utils.traitlets import (
37 28 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
38 29 )
39 30
40 31 from IPython.parallel import error, util
41 32 from IPython.parallel.factory import RegistrationFactory
42 33
43 34 from IPython.kernel.zmq.session import SessionFactory
44 35
45 36 from .heartmonitor import HeartMonitor
46 37
47 38 #-----------------------------------------------------------------------------
48 39 # Code
49 40 #-----------------------------------------------------------------------------
50 41
51 42 def _passer(*args, **kwargs):
52 43 return
53 44
54 45 def _printer(*args, **kwargs):
55 46 print (args)
56 47 print (kwargs)
57 48
58 49 def empty_record():
59 50 """Return an empty dict with all record keys."""
60 51 return {
61 52 'msg_id' : None,
62 53 'header' : None,
63 54 'metadata' : None,
64 55 'content': None,
65 56 'buffers': None,
66 57 'submitted': None,
67 58 'client_uuid' : None,
68 59 'engine_uuid' : None,
69 60 'started': None,
70 61 'completed': None,
71 62 'resubmitted': None,
72 63 'received': None,
73 64 'result_header' : None,
74 65 'result_metadata' : None,
75 66 'result_content' : None,
76 67 'result_buffers' : None,
77 68 'queue' : None,
78 'pyin' : None,
69 'execute_input' : None,
79 70 'pyout': None,
80 71 'pyerr': None,
81 72 'stdout': '',
82 73 'stderr': '',
83 74 }
84 75
85 76 def init_record(msg):
86 77 """Initialize a TaskRecord based on a request."""
87 78 header = msg['header']
88 79 return {
89 80 'msg_id' : header['msg_id'],
90 81 'header' : header,
91 82 'content': msg['content'],
92 83 'metadata': msg['metadata'],
93 84 'buffers': msg['buffers'],
94 85 'submitted': header['date'],
95 86 'client_uuid' : None,
96 87 'engine_uuid' : None,
97 88 'started': None,
98 89 'completed': None,
99 90 'resubmitted': None,
100 91 'received': None,
101 92 'result_header' : None,
102 93 'result_metadata': None,
103 94 'result_content' : None,
104 95 'result_buffers' : None,
105 96 'queue' : None,
106 'pyin' : None,
97 'execute_input' : None,
107 98 'pyout': None,
108 99 'pyerr': None,
109 100 'stdout': '',
110 101 'stderr': '',
111 102 }
112 103
113 104
114 105 class EngineConnector(HasTraits):
115 106 """A simple object for accessing the various zmq connections of an object.
116 107 Attributes are:
117 108 id (int): engine ID
118 109 uuid (unicode): engine UUID
119 110 pending: set of msg_ids
120 111 stallback: DelayedCallback for stalled registration
121 112 """
122 113
123 114 id = Integer(0)
124 115 uuid = Unicode()
125 116 pending = Set()
126 117 stallback = Instance(ioloop.DelayedCallback)
127 118
128 119
129 120 _db_shortcuts = {
130 121 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
131 122 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
132 123 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
133 124 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
134 125 }
135 126
136 127 class HubFactory(RegistrationFactory):
137 128 """The Configurable for setting up a Hub."""
138 129
139 130 # port-pairs for monitoredqueues:
140 131 hb = Tuple(Integer,Integer,config=True,
141 132 help="""PUB/ROUTER Port pair for Engine heartbeats""")
142 133 def _hb_default(self):
143 134 return tuple(util.select_random_ports(2))
144 135
145 136 mux = Tuple(Integer,Integer,config=True,
146 137 help="""Client/Engine Port pair for MUX queue""")
147 138
148 139 def _mux_default(self):
149 140 return tuple(util.select_random_ports(2))
150 141
151 142 task = Tuple(Integer,Integer,config=True,
152 143 help="""Client/Engine Port pair for Task queue""")
153 144 def _task_default(self):
154 145 return tuple(util.select_random_ports(2))
155 146
156 147 control = Tuple(Integer,Integer,config=True,
157 148 help="""Client/Engine Port pair for Control queue""")
158 149
159 150 def _control_default(self):
160 151 return tuple(util.select_random_ports(2))
161 152
162 153 iopub = Tuple(Integer,Integer,config=True,
163 154 help="""Client/Engine Port pair for IOPub relay""")
164 155
165 156 def _iopub_default(self):
166 157 return tuple(util.select_random_ports(2))
167 158
168 159 # single ports:
169 160 mon_port = Integer(config=True,
170 161 help="""Monitor (SUB) port for queue traffic""")
171 162
172 163 def _mon_port_default(self):
173 164 return util.select_random_ports(1)[0]
174 165
175 166 notifier_port = Integer(config=True,
176 167 help="""PUB port for sending engine status notifications""")
177 168
178 169 def _notifier_port_default(self):
179 170 return util.select_random_ports(1)[0]
180 171
181 172 engine_ip = Unicode(config=True,
182 173 help="IP on which to listen for engine connections. [default: loopback]")
183 174 def _engine_ip_default(self):
184 175 return localhost()
185 176 engine_transport = Unicode('tcp', config=True,
186 177 help="0MQ transport for engine connections. [default: tcp]")
187 178
188 179 client_ip = Unicode(config=True,
189 180 help="IP on which to listen for client connections. [default: loopback]")
190 181 client_transport = Unicode('tcp', config=True,
191 182 help="0MQ transport for client connections. [default : tcp]")
192 183
193 184 monitor_ip = Unicode(config=True,
194 185 help="IP on which to listen for monitor messages. [default: loopback]")
195 186 monitor_transport = Unicode('tcp', config=True,
196 187 help="0MQ transport for monitor messages. [default : tcp]")
197 188
198 189 _client_ip_default = _monitor_ip_default = _engine_ip_default
199 190
200 191
201 192 monitor_url = Unicode('')
202 193
203 194 db_class = DottedObjectName('NoDB',
204 195 config=True, help="""The class to use for the DB backend
205 196
206 197 Options include:
207 198
208 199 SQLiteDB: SQLite
209 200 MongoDB : use MongoDB
210 201 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
211 202 NoDB : disable database altogether (default)
212 203
213 204 """)
214 205
215 206 registration_timeout = Integer(0, config=True,
216 207 help="Engine registration timeout in seconds [default: max(30,"
217 208 "10*heartmonitor.period)]" )
218 209
219 210 def _registration_timeout_default(self):
220 211 if self.heartmonitor is None:
221 212 # early initialization, this value will be ignored
222 213 return 0
223 214 # heartmonitor period is in milliseconds, so 10x in seconds is .01
224 215 return max(30, int(.01 * self.heartmonitor.period))
225 216
226 217 # not configurable
227 218 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
228 219 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
229 220
230 221 def _ip_changed(self, name, old, new):
231 222 self.engine_ip = new
232 223 self.client_ip = new
233 224 self.monitor_ip = new
234 225 self._update_monitor_url()
235 226
236 227 def _update_monitor_url(self):
237 228 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
238 229
239 230 def _transport_changed(self, name, old, new):
240 231 self.engine_transport = new
241 232 self.client_transport = new
242 233 self.monitor_transport = new
243 234 self._update_monitor_url()
244 235
245 236 def __init__(self, **kwargs):
246 237 super(HubFactory, self).__init__(**kwargs)
247 238 self._update_monitor_url()
248 239
249 240
250 241 def construct(self):
251 242 self.init_hub()
252 243
253 244 def start(self):
254 245 self.heartmonitor.start()
255 246 self.log.info("Heartmonitor started")
256 247
257 248 def client_url(self, channel):
258 249 """return full zmq url for a named client channel"""
259 250 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
260 251
261 252 def engine_url(self, channel):
262 253 """return full zmq url for a named engine channel"""
263 254 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
264 255
265 256 def init_hub(self):
266 257 """construct Hub object"""
267 258
268 259 ctx = self.context
269 260 loop = self.loop
270 261 if 'TaskScheduler.scheme_name' in self.config:
271 262 scheme = self.config.TaskScheduler.scheme_name
272 263 else:
273 264 from .scheduler import TaskScheduler
274 265 scheme = TaskScheduler.scheme_name.get_default_value()
275 266
276 267 # build connection dicts
277 268 engine = self.engine_info = {
278 269 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
279 270 'registration' : self.regport,
280 271 'control' : self.control[1],
281 272 'mux' : self.mux[1],
282 273 'hb_ping' : self.hb[0],
283 274 'hb_pong' : self.hb[1],
284 275 'task' : self.task[1],
285 276 'iopub' : self.iopub[1],
286 277 }
287 278
288 279 client = self.client_info = {
289 280 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
290 281 'registration' : self.regport,
291 282 'control' : self.control[0],
292 283 'mux' : self.mux[0],
293 284 'task' : self.task[0],
294 285 'task_scheme' : scheme,
295 286 'iopub' : self.iopub[0],
296 287 'notification' : self.notifier_port,
297 288 }
298 289
299 290 self.log.debug("Hub engine addrs: %s", self.engine_info)
300 291 self.log.debug("Hub client addrs: %s", self.client_info)
301 292
302 293 # Registrar socket
303 294 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
304 295 util.set_hwm(q, 0)
305 296 q.bind(self.client_url('registration'))
306 297 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
307 298 if self.client_ip != self.engine_ip:
308 299 q.bind(self.engine_url('registration'))
309 300 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
310 301
311 302 ### Engine connections ###
312 303
313 304 # heartbeat
314 305 hpub = ctx.socket(zmq.PUB)
315 306 hpub.bind(self.engine_url('hb_ping'))
316 307 hrep = ctx.socket(zmq.ROUTER)
317 308 util.set_hwm(hrep, 0)
318 309 hrep.bind(self.engine_url('hb_pong'))
319 310 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
320 311 pingstream=ZMQStream(hpub,loop),
321 312 pongstream=ZMQStream(hrep,loop)
322 313 )
323 314
324 315 ### Client connections ###
325 316
326 317 # Notifier socket
327 318 n = ZMQStream(ctx.socket(zmq.PUB), loop)
328 319 n.bind(self.client_url('notification'))
329 320
330 321 ### build and launch the queues ###
331 322
332 323 # monitor socket
333 324 sub = ctx.socket(zmq.SUB)
334 325 sub.setsockopt(zmq.SUBSCRIBE, b"")
335 326 sub.bind(self.monitor_url)
336 327 sub.bind('inproc://monitor')
337 328 sub = ZMQStream(sub, loop)
338 329
339 330 # connect the db
340 331 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
341 332 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
342 333 self.db = import_item(str(db_class))(session=self.session.session,
343 334 parent=self, log=self.log)
344 335 time.sleep(.25)
345 336
346 337 # resubmit stream
347 338 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
348 339 url = util.disambiguate_url(self.client_url('task'))
349 340 r.connect(url)
350 341
351 342 # convert seconds to msec
352 343 registration_timeout = 1000*self.registration_timeout
353 344
354 345 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
355 346 query=q, notifier=n, resubmit=r, db=self.db,
356 347 engine_info=self.engine_info, client_info=self.client_info,
357 348 log=self.log, registration_timeout=registration_timeout)
358 349
359 350
360 351 class Hub(SessionFactory):
361 352 """The IPython Controller Hub with 0MQ connections
362 353
363 354 Parameters
364 355 ==========
365 356 loop: zmq IOLoop instance
366 357 session: Session object
367 358 <removed> context: zmq context for creating new connections (?)
368 359 queue: ZMQStream for monitoring the command queue (SUB)
369 360 query: ZMQStream for engine registration and client queries requests (ROUTER)
370 361 heartbeat: HeartMonitor object checking the pulse of the engines
371 362 notifier: ZMQStream for broadcasting engine registration changes (PUB)
372 363 db: connection to db for out of memory logging of commands
373 364 NotImplemented
374 365 engine_info: dict of zmq connection information for engines to connect
375 366 to the queues.
376 367 client_info: dict of zmq connection information for engines to connect
377 368 to the queues.
378 369 """
379 370
380 371 engine_state_file = Unicode()
381 372
382 373 # internal data structures:
383 374 ids=Set() # engine IDs
384 375 keytable=Dict()
385 376 by_ident=Dict()
386 377 engines=Dict()
387 378 clients=Dict()
388 379 hearts=Dict()
389 380 pending=Set()
390 381 queues=Dict() # pending msg_ids keyed by engine_id
391 382 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
392 383 completed=Dict() # completed msg_ids keyed by engine_id
393 384 all_completed=Set() # completed msg_ids keyed by engine_id
394 385 dead_engines=Set() # completed msg_ids keyed by engine_id
395 386 unassigned=Set() # set of task msg_ds not yet assigned a destination
396 387 incoming_registrations=Dict()
397 388 registration_timeout=Integer()
398 389 _idcounter=Integer(0)
399 390
400 391 # objects from constructor:
401 392 query=Instance(ZMQStream)
402 393 monitor=Instance(ZMQStream)
403 394 notifier=Instance(ZMQStream)
404 395 resubmit=Instance(ZMQStream)
405 396 heartmonitor=Instance(HeartMonitor)
406 397 db=Instance(object)
407 398 client_info=Dict()
408 399 engine_info=Dict()
409 400
410 401
411 402 def __init__(self, **kwargs):
412 403 """
413 404 # universal:
414 405 loop: IOLoop for creating future connections
415 406 session: streamsession for sending serialized data
416 407 # engine:
417 408 queue: ZMQStream for monitoring queue messages
418 409 query: ZMQStream for engine+client registration and client requests
419 410 heartbeat: HeartMonitor object for tracking engines
420 411 # extra:
421 412 db: ZMQStream for db connection (NotImplemented)
422 413 engine_info: zmq address/protocol dict for engine connections
423 414 client_info: zmq address/protocol dict for client connections
424 415 """
425 416
426 417 super(Hub, self).__init__(**kwargs)
427 418
428 419 # register our callbacks
429 420 self.query.on_recv(self.dispatch_query)
430 421 self.monitor.on_recv(self.dispatch_monitor_traffic)
431 422
432 423 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
433 424 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
434 425
435 426 self.monitor_handlers = {b'in' : self.save_queue_request,
436 427 b'out': self.save_queue_result,
437 428 b'intask': self.save_task_request,
438 429 b'outtask': self.save_task_result,
439 430 b'tracktask': self.save_task_destination,
440 431 b'incontrol': _passer,
441 432 b'outcontrol': _passer,
442 433 b'iopub': self.save_iopub_message,
443 434 }
444 435
445 436 self.query_handlers = {'queue_request': self.queue_status,
446 437 'result_request': self.get_results,
447 438 'history_request': self.get_history,
448 439 'db_request': self.db_query,
449 440 'purge_request': self.purge_results,
450 441 'load_request': self.check_load,
451 442 'resubmit_request': self.resubmit_task,
452 443 'shutdown_request': self.shutdown_request,
453 444 'registration_request' : self.register_engine,
454 445 'unregistration_request' : self.unregister_engine,
455 446 'connection_request': self.connection_request,
456 447 }
457 448
458 449 # ignore resubmit replies
459 450 self.resubmit.on_recv(lambda msg: None, copy=False)
460 451
461 452 self.log.info("hub::created hub")
462 453
463 454 @property
464 455 def _next_id(self):
465 456 """gemerate a new ID.
466 457
467 458 No longer reuse old ids, just count from 0."""
468 459 newid = self._idcounter
469 460 self._idcounter += 1
470 461 return newid
471 462 # newid = 0
472 463 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
473 464 # # print newid, self.ids, self.incoming_registrations
474 465 # while newid in self.ids or newid in incoming:
475 466 # newid += 1
476 467 # return newid
477 468
478 469 #-----------------------------------------------------------------------------
479 470 # message validation
480 471 #-----------------------------------------------------------------------------
481 472
482 473 def _validate_targets(self, targets):
483 474 """turn any valid targets argument into a list of integer ids"""
484 475 if targets is None:
485 476 # default to all
486 477 return self.ids
487 478
488 479 if isinstance(targets, (int,str,unicode_type)):
489 480 # only one target specified
490 481 targets = [targets]
491 482 _targets = []
492 483 for t in targets:
493 484 # map raw identities to ids
494 485 if isinstance(t, (str,unicode_type)):
495 486 t = self.by_ident.get(cast_bytes(t), t)
496 487 _targets.append(t)
497 488 targets = _targets
498 489 bad_targets = [ t for t in targets if t not in self.ids ]
499 490 if bad_targets:
500 491 raise IndexError("No Such Engine: %r" % bad_targets)
501 492 if not targets:
502 493 raise IndexError("No Engines Registered")
503 494 return targets
504 495
505 496 #-----------------------------------------------------------------------------
506 497 # dispatch methods (1 per stream)
507 498 #-----------------------------------------------------------------------------
508 499
509 500
510 501 @util.log_errors
511 502 def dispatch_monitor_traffic(self, msg):
512 503 """all ME and Task queue messages come through here, as well as
513 504 IOPub traffic."""
514 505 self.log.debug("monitor traffic: %r", msg[0])
515 506 switch = msg[0]
516 507 try:
517 508 idents, msg = self.session.feed_identities(msg[1:])
518 509 except ValueError:
519 510 idents=[]
520 511 if not idents:
521 512 self.log.error("Monitor message without topic: %r", msg)
522 513 return
523 514 handler = self.monitor_handlers.get(switch, None)
524 515 if handler is not None:
525 516 handler(idents, msg)
526 517 else:
527 518 self.log.error("Unrecognized monitor topic: %r", switch)
528 519
529 520
530 521 @util.log_errors
531 522 def dispatch_query(self, msg):
532 523 """Route registration requests and queries from clients."""
533 524 try:
534 525 idents, msg = self.session.feed_identities(msg)
535 526 except ValueError:
536 527 idents = []
537 528 if not idents:
538 529 self.log.error("Bad Query Message: %r", msg)
539 530 return
540 531 client_id = idents[0]
541 532 try:
542 533 msg = self.session.unserialize(msg, content=True)
543 534 except Exception:
544 535 content = error.wrap_exception()
545 536 self.log.error("Bad Query Message: %r", msg, exc_info=True)
546 537 self.session.send(self.query, "hub_error", ident=client_id,
547 538 content=content)
548 539 return
549 540 # print client_id, header, parent, content
550 541 #switch on message type:
551 542 msg_type = msg['header']['msg_type']
552 543 self.log.info("client::client %r requested %r", client_id, msg_type)
553 544 handler = self.query_handlers.get(msg_type, None)
554 545 try:
555 546 assert handler is not None, "Bad Message Type: %r" % msg_type
556 547 except:
557 548 content = error.wrap_exception()
558 549 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
559 550 self.session.send(self.query, "hub_error", ident=client_id,
560 551 content=content)
561 552 return
562 553
563 554 else:
564 555 handler(idents, msg)
565 556
566 557 def dispatch_db(self, msg):
567 558 """"""
568 559 raise NotImplementedError
569 560
570 561 #---------------------------------------------------------------------------
571 562 # handler methods (1 per event)
572 563 #---------------------------------------------------------------------------
573 564
574 565 #----------------------- Heartbeat --------------------------------------
575 566
576 567 def handle_new_heart(self, heart):
577 568 """handler to attach to heartbeater.
578 569 Called when a new heart starts to beat.
579 570 Triggers completion of registration."""
580 571 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
581 572 if heart not in self.incoming_registrations:
582 573 self.log.info("heartbeat::ignoring new heart: %r", heart)
583 574 else:
584 575 self.finish_registration(heart)
585 576
586 577
587 578 def handle_heart_failure(self, heart):
588 579 """handler to attach to heartbeater.
589 580 called when a previously registered heart fails to respond to beat request.
590 581 triggers unregistration"""
591 582 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
592 583 eid = self.hearts.get(heart, None)
593 584 uuid = self.engines[eid].uuid
594 585 if eid is None or self.keytable[eid] in self.dead_engines:
595 586 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
596 587 else:
597 588 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
598 589
599 590 #----------------------- MUX Queue Traffic ------------------------------
600 591
601 592 def save_queue_request(self, idents, msg):
602 593 if len(idents) < 2:
603 594 self.log.error("invalid identity prefix: %r", idents)
604 595 return
605 596 queue_id, client_id = idents[:2]
606 597 try:
607 598 msg = self.session.unserialize(msg)
608 599 except Exception:
609 600 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
610 601 return
611 602
612 603 eid = self.by_ident.get(queue_id, None)
613 604 if eid is None:
614 605 self.log.error("queue::target %r not registered", queue_id)
615 606 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
616 607 return
617 608 record = init_record(msg)
618 609 msg_id = record['msg_id']
619 610 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
620 611 # Unicode in records
621 612 record['engine_uuid'] = queue_id.decode('ascii')
622 613 record['client_uuid'] = msg['header']['session']
623 614 record['queue'] = 'mux'
624 615
625 616 try:
626 617 # it's posible iopub arrived first:
627 618 existing = self.db.get_record(msg_id)
628 619 for key,evalue in iteritems(existing):
629 620 rvalue = record.get(key, None)
630 621 if evalue and rvalue and evalue != rvalue:
631 622 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
632 623 elif evalue and not rvalue:
633 624 record[key] = evalue
634 625 try:
635 626 self.db.update_record(msg_id, record)
636 627 except Exception:
637 628 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
638 629 except KeyError:
639 630 try:
640 631 self.db.add_record(msg_id, record)
641 632 except Exception:
642 633 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
643 634
644 635
645 636 self.pending.add(msg_id)
646 637 self.queues[eid].append(msg_id)
647 638
648 639 def save_queue_result(self, idents, msg):
649 640 if len(idents) < 2:
650 641 self.log.error("invalid identity prefix: %r", idents)
651 642 return
652 643
653 644 client_id, queue_id = idents[:2]
654 645 try:
655 646 msg = self.session.unserialize(msg)
656 647 except Exception:
657 648 self.log.error("queue::engine %r sent invalid message to %r: %r",
658 649 queue_id, client_id, msg, exc_info=True)
659 650 return
660 651
661 652 eid = self.by_ident.get(queue_id, None)
662 653 if eid is None:
663 654 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
664 655 return
665 656
666 657 parent = msg['parent_header']
667 658 if not parent:
668 659 return
669 660 msg_id = parent['msg_id']
670 661 if msg_id in self.pending:
671 662 self.pending.remove(msg_id)
672 663 self.all_completed.add(msg_id)
673 664 self.queues[eid].remove(msg_id)
674 665 self.completed[eid].append(msg_id)
675 666 self.log.info("queue::request %r completed on %s", msg_id, eid)
676 667 elif msg_id not in self.all_completed:
677 668 # it could be a result from a dead engine that died before delivering the
678 669 # result
679 670 self.log.warn("queue:: unknown msg finished %r", msg_id)
680 671 return
681 672 # update record anyway, because the unregistration could have been premature
682 673 rheader = msg['header']
683 674 md = msg['metadata']
684 675 completed = rheader['date']
685 676 started = extract_dates(md.get('started', None))
686 677 result = {
687 678 'result_header' : rheader,
688 679 'result_metadata': md,
689 680 'result_content': msg['content'],
690 681 'received': datetime.now(),
691 682 'started' : started,
692 683 'completed' : completed
693 684 }
694 685
695 686 result['result_buffers'] = msg['buffers']
696 687 try:
697 688 self.db.update_record(msg_id, result)
698 689 except Exception:
699 690 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
700 691
701 692
702 693 #--------------------- Task Queue Traffic ------------------------------
703 694
704 695 def save_task_request(self, idents, msg):
705 696 """Save the submission of a task."""
706 697 client_id = idents[0]
707 698
708 699 try:
709 700 msg = self.session.unserialize(msg)
710 701 except Exception:
711 702 self.log.error("task::client %r sent invalid task message: %r",
712 703 client_id, msg, exc_info=True)
713 704 return
714 705 record = init_record(msg)
715 706
716 707 record['client_uuid'] = msg['header']['session']
717 708 record['queue'] = 'task'
718 709 header = msg['header']
719 710 msg_id = header['msg_id']
720 711 self.pending.add(msg_id)
721 712 self.unassigned.add(msg_id)
722 713 try:
723 714 # it's posible iopub arrived first:
724 715 existing = self.db.get_record(msg_id)
725 716 if existing['resubmitted']:
726 717 for key in ('submitted', 'client_uuid', 'buffers'):
727 718 # don't clobber these keys on resubmit
728 719 # submitted and client_uuid should be different
729 720 # and buffers might be big, and shouldn't have changed
730 721 record.pop(key)
731 722 # still check content,header which should not change
732 723 # but are not expensive to compare as buffers
733 724
734 725 for key,evalue in iteritems(existing):
735 726 if key.endswith('buffers'):
736 727 # don't compare buffers
737 728 continue
738 729 rvalue = record.get(key, None)
739 730 if evalue and rvalue and evalue != rvalue:
740 731 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
741 732 elif evalue and not rvalue:
742 733 record[key] = evalue
743 734 try:
744 735 self.db.update_record(msg_id, record)
745 736 except Exception:
746 737 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
747 738 except KeyError:
748 739 try:
749 740 self.db.add_record(msg_id, record)
750 741 except Exception:
751 742 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
752 743 except Exception:
753 744 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
754 745
755 746 def save_task_result(self, idents, msg):
756 747 """save the result of a completed task."""
757 748 client_id = idents[0]
758 749 try:
759 750 msg = self.session.unserialize(msg)
760 751 except Exception:
761 752 self.log.error("task::invalid task result message send to %r: %r",
762 753 client_id, msg, exc_info=True)
763 754 return
764 755
765 756 parent = msg['parent_header']
766 757 if not parent:
767 758 # print msg
768 759 self.log.warn("Task %r had no parent!", msg)
769 760 return
770 761 msg_id = parent['msg_id']
771 762 if msg_id in self.unassigned:
772 763 self.unassigned.remove(msg_id)
773 764
774 765 header = msg['header']
775 766 md = msg['metadata']
776 767 engine_uuid = md.get('engine', u'')
777 768 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
778 769
779 770 status = md.get('status', None)
780 771
781 772 if msg_id in self.pending:
782 773 self.log.info("task::task %r finished on %s", msg_id, eid)
783 774 self.pending.remove(msg_id)
784 775 self.all_completed.add(msg_id)
785 776 if eid is not None:
786 777 if status != 'aborted':
787 778 self.completed[eid].append(msg_id)
788 779 if msg_id in self.tasks[eid]:
789 780 self.tasks[eid].remove(msg_id)
790 781 completed = header['date']
791 782 started = extract_dates(md.get('started', None))
792 783 result = {
793 784 'result_header' : header,
794 785 'result_metadata': msg['metadata'],
795 786 'result_content': msg['content'],
796 787 'started' : started,
797 788 'completed' : completed,
798 789 'received' : datetime.now(),
799 790 'engine_uuid': engine_uuid,
800 791 }
801 792
802 793 result['result_buffers'] = msg['buffers']
803 794 try:
804 795 self.db.update_record(msg_id, result)
805 796 except Exception:
806 797 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
807 798
808 799 else:
809 800 self.log.debug("task::unknown task %r finished", msg_id)
810 801
811 802 def save_task_destination(self, idents, msg):
812 803 try:
813 804 msg = self.session.unserialize(msg, content=True)
814 805 except Exception:
815 806 self.log.error("task::invalid task tracking message", exc_info=True)
816 807 return
817 808 content = msg['content']
818 809 # print (content)
819 810 msg_id = content['msg_id']
820 811 engine_uuid = content['engine_id']
821 812 eid = self.by_ident[cast_bytes(engine_uuid)]
822 813
823 814 self.log.info("task::task %r arrived on %r", msg_id, eid)
824 815 if msg_id in self.unassigned:
825 816 self.unassigned.remove(msg_id)
826 817 # else:
827 818 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
828 819
829 820 self.tasks[eid].append(msg_id)
830 821 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
831 822 try:
832 823 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
833 824 except Exception:
834 825 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
835 826
836 827
837 828 def mia_task_request(self, idents, msg):
838 829 raise NotImplementedError
839 830 client_id = idents[0]
840 831 # content = dict(mia=self.mia,status='ok')
841 832 # self.session.send('mia_reply', content=content, idents=client_id)
842 833
843 834
844 835 #--------------------- IOPub Traffic ------------------------------
845 836
846 837 def save_iopub_message(self, topics, msg):
847 838 """save an iopub message into the db"""
848 839 # print (topics)
849 840 try:
850 841 msg = self.session.unserialize(msg, content=True)
851 842 except Exception:
852 843 self.log.error("iopub::invalid IOPub message", exc_info=True)
853 844 return
854 845
855 846 parent = msg['parent_header']
856 847 if not parent:
857 848 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
858 849 return
859 850 msg_id = parent['msg_id']
860 851 msg_type = msg['header']['msg_type']
861 852 content = msg['content']
862 853
863 854 # ensure msg_id is in db
864 855 try:
865 856 rec = self.db.get_record(msg_id)
866 857 except KeyError:
867 858 rec = empty_record()
868 859 rec['msg_id'] = msg_id
869 860 self.db.add_record(msg_id, rec)
870 861 # stream
871 862 d = {}
872 863 if msg_type == 'stream':
873 864 name = content['name']
874 865 s = rec[name] or ''
875 866 d[name] = s + content['data']
876 867
877 868 elif msg_type == 'pyerr':
878 869 d['pyerr'] = content
879 elif msg_type == 'pyin':
880 d['pyin'] = content['code']
870 elif msg_type == 'execute_input':
871 d['execute_input'] = content['code']
881 872 elif msg_type in ('display_data', 'pyout'):
882 873 d[msg_type] = content
883 874 elif msg_type == 'status':
884 875 pass
885 876 elif msg_type == 'data_pub':
886 877 self.log.info("ignored data_pub message for %s" % msg_id)
887 878 else:
888 879 self.log.warn("unhandled iopub msg_type: %r", msg_type)
889 880
890 881 if not d:
891 882 return
892 883
893 884 try:
894 885 self.db.update_record(msg_id, d)
895 886 except Exception:
896 887 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
897 888
898 889
899 890
900 891 #-------------------------------------------------------------------------
901 892 # Registration requests
902 893 #-------------------------------------------------------------------------
903 894
904 895 def connection_request(self, client_id, msg):
905 896 """Reply with connection addresses for clients."""
906 897 self.log.info("client::client %r connected", client_id)
907 898 content = dict(status='ok')
908 899 jsonable = {}
909 900 for k,v in iteritems(self.keytable):
910 901 if v not in self.dead_engines:
911 902 jsonable[str(k)] = v
912 903 content['engines'] = jsonable
913 904 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
914 905
915 906 def register_engine(self, reg, msg):
916 907 """Register a new engine."""
917 908 content = msg['content']
918 909 try:
919 910 uuid = content['uuid']
920 911 except KeyError:
921 912 self.log.error("registration::queue not specified", exc_info=True)
922 913 return
923 914
924 915 eid = self._next_id
925 916
926 917 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
927 918
928 919 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
929 920 # check if requesting available IDs:
930 921 if cast_bytes(uuid) in self.by_ident:
931 922 try:
932 923 raise KeyError("uuid %r in use" % uuid)
933 924 except:
934 925 content = error.wrap_exception()
935 926 self.log.error("uuid %r in use", uuid, exc_info=True)
936 927 else:
937 928 for h, ec in iteritems(self.incoming_registrations):
938 929 if uuid == h:
939 930 try:
940 931 raise KeyError("heart_id %r in use" % uuid)
941 932 except:
942 933 self.log.error("heart_id %r in use", uuid, exc_info=True)
943 934 content = error.wrap_exception()
944 935 break
945 936 elif uuid == ec.uuid:
946 937 try:
947 938 raise KeyError("uuid %r in use" % uuid)
948 939 except:
949 940 self.log.error("uuid %r in use", uuid, exc_info=True)
950 941 content = error.wrap_exception()
951 942 break
952 943
953 944 msg = self.session.send(self.query, "registration_reply",
954 945 content=content,
955 946 ident=reg)
956 947
957 948 heart = cast_bytes(uuid)
958 949
959 950 if content['status'] == 'ok':
960 951 if heart in self.heartmonitor.hearts:
961 952 # already beating
962 953 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
963 954 self.finish_registration(heart)
964 955 else:
965 956 purge = lambda : self._purge_stalled_registration(heart)
966 957 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
967 958 dc.start()
968 959 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
969 960 else:
970 961 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
971 962
972 963 return eid
973 964
974 965 def unregister_engine(self, ident, msg):
975 966 """Unregister an engine that explicitly requested to leave."""
976 967 try:
977 968 eid = msg['content']['id']
978 969 except:
979 970 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
980 971 return
981 972 self.log.info("registration::unregister_engine(%r)", eid)
982 973 # print (eid)
983 974 uuid = self.keytable[eid]
984 975 content=dict(id=eid, uuid=uuid)
985 976 self.dead_engines.add(uuid)
986 977 # self.ids.remove(eid)
987 978 # uuid = self.keytable.pop(eid)
988 979 #
989 980 # ec = self.engines.pop(eid)
990 981 # self.hearts.pop(ec.heartbeat)
991 982 # self.by_ident.pop(ec.queue)
992 983 # self.completed.pop(eid)
993 984 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
994 985 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
995 986 dc.start()
996 987 ############## TODO: HANDLE IT ################
997 988
998 989 self._save_engine_state()
999 990
1000 991 if self.notifier:
1001 992 self.session.send(self.notifier, "unregistration_notification", content=content)
1002 993
1003 994 def _handle_stranded_msgs(self, eid, uuid):
1004 995 """Handle messages known to be on an engine when the engine unregisters.
1005 996
1006 997 It is possible that this will fire prematurely - that is, an engine will
1007 998 go down after completing a result, and the client will be notified
1008 999 that the result failed and later receive the actual result.
1009 1000 """
1010 1001
1011 1002 outstanding = self.queues[eid]
1012 1003
1013 1004 for msg_id in outstanding:
1014 1005 self.pending.remove(msg_id)
1015 1006 self.all_completed.add(msg_id)
1016 1007 try:
1017 1008 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1018 1009 except:
1019 1010 content = error.wrap_exception()
1020 1011 # build a fake header:
1021 1012 header = {}
1022 1013 header['engine'] = uuid
1023 1014 header['date'] = datetime.now()
1024 1015 rec = dict(result_content=content, result_header=header, result_buffers=[])
1025 1016 rec['completed'] = header['date']
1026 1017 rec['engine_uuid'] = uuid
1027 1018 try:
1028 1019 self.db.update_record(msg_id, rec)
1029 1020 except Exception:
1030 1021 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1031 1022
1032 1023
1033 1024 def finish_registration(self, heart):
1034 1025 """Second half of engine registration, called after our HeartMonitor
1035 1026 has received a beat from the Engine's Heart."""
1036 1027 try:
1037 1028 ec = self.incoming_registrations.pop(heart)
1038 1029 except KeyError:
1039 1030 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
1040 1031 return
1041 1032 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1042 1033 if ec.stallback is not None:
1043 1034 ec.stallback.stop()
1044 1035 eid = ec.id
1045 1036 self.ids.add(eid)
1046 1037 self.keytable[eid] = ec.uuid
1047 1038 self.engines[eid] = ec
1048 1039 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1049 1040 self.queues[eid] = list()
1050 1041 self.tasks[eid] = list()
1051 1042 self.completed[eid] = list()
1052 1043 self.hearts[heart] = eid
1053 1044 content = dict(id=eid, uuid=self.engines[eid].uuid)
1054 1045 if self.notifier:
1055 1046 self.session.send(self.notifier, "registration_notification", content=content)
1056 1047 self.log.info("engine::Engine Connected: %i", eid)
1057 1048
1058 1049 self._save_engine_state()
1059 1050
1060 1051 def _purge_stalled_registration(self, heart):
1061 1052 if heart in self.incoming_registrations:
1062 1053 ec = self.incoming_registrations.pop(heart)
1063 1054 self.log.info("registration::purging stalled registration: %i", ec.id)
1064 1055 else:
1065 1056 pass
1066 1057
1067 1058 #-------------------------------------------------------------------------
1068 1059 # Engine State
1069 1060 #-------------------------------------------------------------------------
1070 1061
1071 1062
1072 1063 def _cleanup_engine_state_file(self):
1073 1064 """cleanup engine state mapping"""
1074 1065
1075 1066 if os.path.exists(self.engine_state_file):
1076 1067 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1077 1068 try:
1078 1069 os.remove(self.engine_state_file)
1079 1070 except IOError:
1080 1071 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1081 1072
1082 1073
1083 1074 def _save_engine_state(self):
1084 1075 """save engine mapping to JSON file"""
1085 1076 if not self.engine_state_file:
1086 1077 return
1087 1078 self.log.debug("save engine state to %s" % self.engine_state_file)
1088 1079 state = {}
1089 1080 engines = {}
1090 1081 for eid, ec in iteritems(self.engines):
1091 1082 if ec.uuid not in self.dead_engines:
1092 1083 engines[eid] = ec.uuid
1093 1084
1094 1085 state['engines'] = engines
1095 1086
1096 1087 state['next_id'] = self._idcounter
1097 1088
1098 1089 with open(self.engine_state_file, 'w') as f:
1099 1090 json.dump(state, f)
1100 1091
1101 1092
1102 1093 def _load_engine_state(self):
1103 1094 """load engine mapping from JSON file"""
1104 1095 if not os.path.exists(self.engine_state_file):
1105 1096 return
1106 1097
1107 1098 self.log.info("loading engine state from %s" % self.engine_state_file)
1108 1099
1109 1100 with open(self.engine_state_file) as f:
1110 1101 state = json.load(f)
1111 1102
1112 1103 save_notifier = self.notifier
1113 1104 self.notifier = None
1114 1105 for eid, uuid in iteritems(state['engines']):
1115 1106 heart = uuid.encode('ascii')
1116 1107 # start with this heart as current and beating:
1117 1108 self.heartmonitor.responses.add(heart)
1118 1109 self.heartmonitor.hearts.add(heart)
1119 1110
1120 1111 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1121 1112 self.finish_registration(heart)
1122 1113
1123 1114 self.notifier = save_notifier
1124 1115
1125 1116 self._idcounter = state['next_id']
1126 1117
1127 1118 #-------------------------------------------------------------------------
1128 1119 # Client Requests
1129 1120 #-------------------------------------------------------------------------
1130 1121
1131 1122 def shutdown_request(self, client_id, msg):
1132 1123 """handle shutdown request."""
1133 1124 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1134 1125 # also notify other clients of shutdown
1135 1126 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1136 1127 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1137 1128 dc.start()
1138 1129
1139 1130 def _shutdown(self):
1140 1131 self.log.info("hub::hub shutting down.")
1141 1132 time.sleep(0.1)
1142 1133 sys.exit(0)
1143 1134
1144 1135
1145 1136 def check_load(self, client_id, msg):
1146 1137 content = msg['content']
1147 1138 try:
1148 1139 targets = content['targets']
1149 1140 targets = self._validate_targets(targets)
1150 1141 except:
1151 1142 content = error.wrap_exception()
1152 1143 self.session.send(self.query, "hub_error",
1153 1144 content=content, ident=client_id)
1154 1145 return
1155 1146
1156 1147 content = dict(status='ok')
1157 1148 # loads = {}
1158 1149 for t in targets:
1159 1150 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1160 1151 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1161 1152
1162 1153
1163 1154 def queue_status(self, client_id, msg):
1164 1155 """Return the Queue status of one or more targets.
1165 1156
1166 1157 If verbose, return the msg_ids, else return len of each type.
1167 1158
1168 1159 Keys:
1169 1160
1170 1161 * queue (pending MUX jobs)
1171 1162 * tasks (pending Task jobs)
1172 1163 * completed (finished jobs from both queues)
1173 1164 """
1174 1165 content = msg['content']
1175 1166 targets = content['targets']
1176 1167 try:
1177 1168 targets = self._validate_targets(targets)
1178 1169 except:
1179 1170 content = error.wrap_exception()
1180 1171 self.session.send(self.query, "hub_error",
1181 1172 content=content, ident=client_id)
1182 1173 return
1183 1174 verbose = content.get('verbose', False)
1184 1175 content = dict(status='ok')
1185 1176 for t in targets:
1186 1177 queue = self.queues[t]
1187 1178 completed = self.completed[t]
1188 1179 tasks = self.tasks[t]
1189 1180 if not verbose:
1190 1181 queue = len(queue)
1191 1182 completed = len(completed)
1192 1183 tasks = len(tasks)
1193 1184 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1194 1185 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1195 1186 # print (content)
1196 1187 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1197 1188
1198 1189 def purge_results(self, client_id, msg):
1199 1190 """Purge results from memory. This method is more valuable before we move
1200 1191 to a DB based message storage mechanism."""
1201 1192 content = msg['content']
1202 1193 self.log.info("Dropping records with %s", content)
1203 1194 msg_ids = content.get('msg_ids', [])
1204 1195 reply = dict(status='ok')
1205 1196 if msg_ids == 'all':
1206 1197 try:
1207 1198 self.db.drop_matching_records(dict(completed={'$ne':None}))
1208 1199 except Exception:
1209 1200 reply = error.wrap_exception()
1210 1201 self.log.exception("Error dropping records")
1211 1202 else:
1212 1203 pending = [m for m in msg_ids if (m in self.pending)]
1213 1204 if pending:
1214 1205 try:
1215 1206 raise IndexError("msg pending: %r" % pending[0])
1216 1207 except:
1217 1208 reply = error.wrap_exception()
1218 1209 self.log.exception("Error dropping records")
1219 1210 else:
1220 1211 try:
1221 1212 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1222 1213 except Exception:
1223 1214 reply = error.wrap_exception()
1224 1215 self.log.exception("Error dropping records")
1225 1216
1226 1217 if reply['status'] == 'ok':
1227 1218 eids = content.get('engine_ids', [])
1228 1219 for eid in eids:
1229 1220 if eid not in self.engines:
1230 1221 try:
1231 1222 raise IndexError("No such engine: %i" % eid)
1232 1223 except:
1233 1224 reply = error.wrap_exception()
1234 1225 self.log.exception("Error dropping records")
1235 1226 break
1236 1227 uid = self.engines[eid].uuid
1237 1228 try:
1238 1229 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1239 1230 except Exception:
1240 1231 reply = error.wrap_exception()
1241 1232 self.log.exception("Error dropping records")
1242 1233 break
1243 1234
1244 1235 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1245 1236
1246 1237 def resubmit_task(self, client_id, msg):
1247 1238 """Resubmit one or more tasks."""
1248 1239 def finish(reply):
1249 1240 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1250 1241
1251 1242 content = msg['content']
1252 1243 msg_ids = content['msg_ids']
1253 1244 reply = dict(status='ok')
1254 1245 try:
1255 1246 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1256 1247 'header', 'content', 'buffers'])
1257 1248 except Exception:
1258 1249 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1259 1250 return finish(error.wrap_exception())
1260 1251
1261 1252 # validate msg_ids
1262 1253 found_ids = [ rec['msg_id'] for rec in records ]
1263 1254 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1264 1255 if len(records) > len(msg_ids):
1265 1256 try:
1266 1257 raise RuntimeError("DB appears to be in an inconsistent state."
1267 1258 "More matching records were found than should exist")
1268 1259 except Exception:
1269 1260 self.log.exception("Failed to resubmit task")
1270 1261 return finish(error.wrap_exception())
1271 1262 elif len(records) < len(msg_ids):
1272 1263 missing = [ m for m in msg_ids if m not in found_ids ]
1273 1264 try:
1274 1265 raise KeyError("No such msg(s): %r" % missing)
1275 1266 except KeyError:
1276 1267 self.log.exception("Failed to resubmit task")
1277 1268 return finish(error.wrap_exception())
1278 1269 elif pending_ids:
1279 1270 pass
1280 1271 # no need to raise on resubmit of pending task, now that we
1281 1272 # resubmit under new ID, but do we want to raise anyway?
1282 1273 # msg_id = invalid_ids[0]
1283 1274 # try:
1284 1275 # raise ValueError("Task(s) %r appears to be inflight" % )
1285 1276 # except Exception:
1286 1277 # return finish(error.wrap_exception())
1287 1278
1288 1279 # mapping of original IDs to resubmitted IDs
1289 1280 resubmitted = {}
1290 1281
1291 1282 # send the messages
1292 1283 for rec in records:
1293 1284 header = rec['header']
1294 1285 msg = self.session.msg(header['msg_type'], parent=header)
1295 1286 msg_id = msg['msg_id']
1296 1287 msg['content'] = rec['content']
1297 1288
1298 1289 # use the old header, but update msg_id and timestamp
1299 1290 fresh = msg['header']
1300 1291 header['msg_id'] = fresh['msg_id']
1301 1292 header['date'] = fresh['date']
1302 1293 msg['header'] = header
1303 1294
1304 1295 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1305 1296
1306 1297 resubmitted[rec['msg_id']] = msg_id
1307 1298 self.pending.add(msg_id)
1308 1299 msg['buffers'] = rec['buffers']
1309 1300 try:
1310 1301 self.db.add_record(msg_id, init_record(msg))
1311 1302 except Exception:
1312 1303 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1313 1304 return finish(error.wrap_exception())
1314 1305
1315 1306 finish(dict(status='ok', resubmitted=resubmitted))
1316 1307
1317 1308 # store the new IDs in the Task DB
1318 1309 for msg_id, resubmit_id in iteritems(resubmitted):
1319 1310 try:
1320 1311 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1321 1312 except Exception:
1322 1313 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1323 1314
1324 1315
1325 1316 def _extract_record(self, rec):
1326 1317 """decompose a TaskRecord dict into subsection of reply for get_result"""
1327 1318 io_dict = {}
1328 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1319 for key in ('execute_input', 'pyout', 'pyerr', 'stdout', 'stderr'):
1329 1320 io_dict[key] = rec[key]
1330 1321 content = {
1331 1322 'header': rec['header'],
1332 1323 'metadata': rec['metadata'],
1333 1324 'result_metadata': rec['result_metadata'],
1334 1325 'result_header' : rec['result_header'],
1335 1326 'result_content': rec['result_content'],
1336 1327 'received' : rec['received'],
1337 1328 'io' : io_dict,
1338 1329 }
1339 1330 if rec['result_buffers']:
1340 1331 buffers = list(map(bytes, rec['result_buffers']))
1341 1332 else:
1342 1333 buffers = []
1343 1334
1344 1335 return content, buffers
1345 1336
1346 1337 def get_results(self, client_id, msg):
1347 1338 """Get the result of 1 or more messages."""
1348 1339 content = msg['content']
1349 1340 msg_ids = sorted(set(content['msg_ids']))
1350 1341 statusonly = content.get('status_only', False)
1351 1342 pending = []
1352 1343 completed = []
1353 1344 content = dict(status='ok')
1354 1345 content['pending'] = pending
1355 1346 content['completed'] = completed
1356 1347 buffers = []
1357 1348 if not statusonly:
1358 1349 try:
1359 1350 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1360 1351 # turn match list into dict, for faster lookup
1361 1352 records = {}
1362 1353 for rec in matches:
1363 1354 records[rec['msg_id']] = rec
1364 1355 except Exception:
1365 1356 content = error.wrap_exception()
1366 1357 self.log.exception("Failed to get results")
1367 1358 self.session.send(self.query, "result_reply", content=content,
1368 1359 parent=msg, ident=client_id)
1369 1360 return
1370 1361 else:
1371 1362 records = {}
1372 1363 for msg_id in msg_ids:
1373 1364 if msg_id in self.pending:
1374 1365 pending.append(msg_id)
1375 1366 elif msg_id in self.all_completed:
1376 1367 completed.append(msg_id)
1377 1368 if not statusonly:
1378 1369 c,bufs = self._extract_record(records[msg_id])
1379 1370 content[msg_id] = c
1380 1371 buffers.extend(bufs)
1381 1372 elif msg_id in records:
1382 1373 if rec['completed']:
1383 1374 completed.append(msg_id)
1384 1375 c,bufs = self._extract_record(records[msg_id])
1385 1376 content[msg_id] = c
1386 1377 buffers.extend(bufs)
1387 1378 else:
1388 1379 pending.append(msg_id)
1389 1380 else:
1390 1381 try:
1391 1382 raise KeyError('No such message: '+msg_id)
1392 1383 except:
1393 1384 content = error.wrap_exception()
1394 1385 break
1395 1386 self.session.send(self.query, "result_reply", content=content,
1396 1387 parent=msg, ident=client_id,
1397 1388 buffers=buffers)
1398 1389
1399 1390 def get_history(self, client_id, msg):
1400 1391 """Get a list of all msg_ids in our DB records"""
1401 1392 try:
1402 1393 msg_ids = self.db.get_history()
1403 1394 except Exception as e:
1404 1395 content = error.wrap_exception()
1405 1396 self.log.exception("Failed to get history")
1406 1397 else:
1407 1398 content = dict(status='ok', history=msg_ids)
1408 1399
1409 1400 self.session.send(self.query, "history_reply", content=content,
1410 1401 parent=msg, ident=client_id)
1411 1402
1412 1403 def db_query(self, client_id, msg):
1413 1404 """Perform a raw query on the task record database."""
1414 1405 content = msg['content']
1415 1406 query = extract_dates(content.get('query', {}))
1416 1407 keys = content.get('keys', None)
1417 1408 buffers = []
1418 1409 empty = list()
1419 1410 try:
1420 1411 records = self.db.find_records(query, keys)
1421 1412 except Exception as e:
1422 1413 content = error.wrap_exception()
1423 1414 self.log.exception("DB query failed")
1424 1415 else:
1425 1416 # extract buffers from reply content:
1426 1417 if keys is not None:
1427 1418 buffer_lens = [] if 'buffers' in keys else None
1428 1419 result_buffer_lens = [] if 'result_buffers' in keys else None
1429 1420 else:
1430 1421 buffer_lens = None
1431 1422 result_buffer_lens = None
1432 1423
1433 1424 for rec in records:
1434 1425 # buffers may be None, so double check
1435 1426 b = rec.pop('buffers', empty) or empty
1436 1427 if buffer_lens is not None:
1437 1428 buffer_lens.append(len(b))
1438 1429 buffers.extend(b)
1439 1430 rb = rec.pop('result_buffers', empty) or empty
1440 1431 if result_buffer_lens is not None:
1441 1432 result_buffer_lens.append(len(rb))
1442 1433 buffers.extend(rb)
1443 1434 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1444 1435 result_buffer_lens=result_buffer_lens)
1445 1436 # self.log.debug (content)
1446 1437 self.session.send(self.query, "db_reply", content=content,
1447 1438 parent=msg, ident=client_id,
1448 1439 buffers=buffers)
1449 1440
@@ -1,422 +1,414 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3"""
2 2
3 Authors:
4
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
13 5
14 6 import json
15 7 import os
16 8 try:
17 9 import cPickle as pickle
18 10 except ImportError:
19 11 import pickle
20 12 from datetime import datetime
21 13
22 14 try:
23 15 import sqlite3
24 16 except ImportError:
25 17 sqlite3 = None
26 18
27 19 from zmq.eventloop import ioloop
28 20
29 21 from IPython.utils.traitlets import Unicode, Instance, List, Dict
30 22 from .dictdb import BaseDB
31 23 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
32 24 from IPython.utils.py3compat import iteritems
33 25
34 26 #-----------------------------------------------------------------------------
35 27 # SQLite operators, adapters, and converters
36 28 #-----------------------------------------------------------------------------
37 29
38 30 try:
39 31 buffer
40 32 except NameError:
41 33 # py3k
42 34 buffer = memoryview
43 35
44 36 operators = {
45 37 '$lt' : "<",
46 38 '$gt' : ">",
47 39 # null is handled weird with ==,!=
48 40 '$eq' : "=",
49 41 '$ne' : "!=",
50 42 '$lte': "<=",
51 43 '$gte': ">=",
52 44 '$in' : ('=', ' OR '),
53 45 '$nin': ('!=', ' AND '),
54 46 # '$all': None,
55 47 # '$mod': None,
56 48 # '$exists' : None
57 49 }
58 50 null_operators = {
59 51 '=' : "IS NULL",
60 52 '!=' : "IS NOT NULL",
61 53 }
62 54
63 55 def _adapt_dict(d):
64 56 return json.dumps(d, default=date_default)
65 57
66 58 def _convert_dict(ds):
67 59 if ds is None:
68 60 return ds
69 61 else:
70 62 if isinstance(ds, bytes):
71 63 # If I understand the sqlite doc correctly, this will always be utf8
72 64 ds = ds.decode('utf8')
73 65 return extract_dates(json.loads(ds))
74 66
75 67 def _adapt_bufs(bufs):
76 68 # this is *horrible*
77 69 # copy buffers into single list and pickle it:
78 70 if bufs and isinstance(bufs[0], (bytes, buffer)):
79 71 return sqlite3.Binary(pickle.dumps(list(map(bytes, bufs)),-1))
80 72 elif bufs:
81 73 return bufs
82 74 else:
83 75 return None
84 76
85 77 def _convert_bufs(bs):
86 78 if bs is None:
87 79 return []
88 80 else:
89 81 return pickle.loads(bytes(bs))
90 82
91 83 #-----------------------------------------------------------------------------
92 84 # SQLiteDB class
93 85 #-----------------------------------------------------------------------------
94 86
95 87 class SQLiteDB(BaseDB):
96 88 """SQLite3 TaskRecord backend."""
97 89
98 90 filename = Unicode('tasks.db', config=True,
99 91 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
100 92 location = Unicode('', config=True,
101 93 help="""The directory containing the sqlite task database. The default
102 94 is to use the cluster_dir location.""")
103 95 table = Unicode("ipython-tasks", config=True,
104 96 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
105 97 a new table will be created with the Hub's IDENT. Specifying the table will result
106 98 in tasks from previous sessions being available via Clients' db_query and
107 99 get_result methods.""")
108 100
109 101 if sqlite3 is not None:
110 102 _db = Instance('sqlite3.Connection')
111 103 else:
112 104 _db = None
113 105 # the ordered list of column names
114 106 _keys = List(['msg_id' ,
115 107 'header' ,
116 108 'metadata',
117 109 'content',
118 110 'buffers',
119 111 'submitted',
120 112 'client_uuid' ,
121 113 'engine_uuid' ,
122 114 'started',
123 115 'completed',
124 116 'resubmitted',
125 117 'received',
126 118 'result_header' ,
127 119 'result_metadata',
128 120 'result_content' ,
129 121 'result_buffers' ,
130 122 'queue' ,
131 'pyin' ,
123 'execute_input' ,
132 124 'pyout',
133 125 'pyerr',
134 126 'stdout',
135 127 'stderr',
136 128 ])
137 129 # sqlite datatypes for checking that db is current format
138 130 _types = Dict({'msg_id' : 'text' ,
139 131 'header' : 'dict text',
140 132 'metadata' : 'dict text',
141 133 'content' : 'dict text',
142 134 'buffers' : 'bufs blob',
143 135 'submitted' : 'timestamp',
144 136 'client_uuid' : 'text',
145 137 'engine_uuid' : 'text',
146 138 'started' : 'timestamp',
147 139 'completed' : 'timestamp',
148 140 'resubmitted' : 'text',
149 141 'received' : 'timestamp',
150 142 'result_header' : 'dict text',
151 143 'result_metadata' : 'dict text',
152 144 'result_content' : 'dict text',
153 145 'result_buffers' : 'bufs blob',
154 146 'queue' : 'text',
155 'pyin' : 'text',
147 'execute_input' : 'text',
156 148 'pyout' : 'text',
157 149 'pyerr' : 'text',
158 150 'stdout' : 'text',
159 151 'stderr' : 'text',
160 152 })
161 153
162 154 def __init__(self, **kwargs):
163 155 super(SQLiteDB, self).__init__(**kwargs)
164 156 if sqlite3 is None:
165 157 raise ImportError("SQLiteDB requires sqlite3")
166 158 if not self.table:
167 159 # use session, and prefix _, since starting with # is illegal
168 160 self.table = '_'+self.session.replace('-','_')
169 161 if not self.location:
170 162 # get current profile
171 163 from IPython.core.application import BaseIPythonApplication
172 164 if BaseIPythonApplication.initialized():
173 165 app = BaseIPythonApplication.instance()
174 166 if app.profile_dir is not None:
175 167 self.location = app.profile_dir.location
176 168 else:
177 169 self.location = u'.'
178 170 else:
179 171 self.location = u'.'
180 172 self._init_db()
181 173
182 174 # register db commit as 2s periodic callback
183 175 # to prevent clogging pipes
184 176 # assumes we are being run in a zmq ioloop app
185 177 loop = ioloop.IOLoop.instance()
186 178 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
187 179 pc.start()
188 180
189 181 def _defaults(self, keys=None):
190 182 """create an empty record"""
191 183 d = {}
192 184 keys = self._keys if keys is None else keys
193 185 for key in keys:
194 186 d[key] = None
195 187 return d
196 188
197 189 def _check_table(self):
198 190 """Ensure that an incorrect table doesn't exist
199 191
200 192 If a bad (old) table does exist, return False
201 193 """
202 194 cursor = self._db.execute("PRAGMA table_info('%s')"%self.table)
203 195 lines = cursor.fetchall()
204 196 if not lines:
205 197 # table does not exist
206 198 return True
207 199 types = {}
208 200 keys = []
209 201 for line in lines:
210 202 keys.append(line[1])
211 203 types[line[1]] = line[2]
212 204 if self._keys != keys:
213 205 # key mismatch
214 206 self.log.warn('keys mismatch')
215 207 return False
216 208 for key in self._keys:
217 209 if types[key] != self._types[key]:
218 210 self.log.warn(
219 211 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
220 212 )
221 213 return False
222 214 return True
223 215
224 216 def _init_db(self):
225 217 """Connect to the database and get new session number."""
226 218 # register adapters
227 219 sqlite3.register_adapter(dict, _adapt_dict)
228 220 sqlite3.register_converter('dict', _convert_dict)
229 221 sqlite3.register_adapter(list, _adapt_bufs)
230 222 sqlite3.register_converter('bufs', _convert_bufs)
231 223 # connect to the db
232 224 dbfile = os.path.join(self.location, self.filename)
233 225 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
234 226 # isolation_level = None)#,
235 227 cached_statements=64)
236 228 # print dir(self._db)
237 229 first_table = previous_table = self.table
238 230 i=0
239 231 while not self._check_table():
240 232 i+=1
241 233 self.table = first_table+'_%i'%i
242 234 self.log.warn(
243 235 "Table %s exists and doesn't match db format, trying %s"%
244 236 (previous_table, self.table)
245 237 )
246 238 previous_table = self.table
247 239
248 240 self._db.execute("""CREATE TABLE IF NOT EXISTS '%s'
249 241 (msg_id text PRIMARY KEY,
250 242 header dict text,
251 243 metadata dict text,
252 244 content dict text,
253 245 buffers bufs blob,
254 246 submitted timestamp,
255 247 client_uuid text,
256 248 engine_uuid text,
257 249 started timestamp,
258 250 completed timestamp,
259 251 resubmitted text,
260 252 received timestamp,
261 253 result_header dict text,
262 254 result_metadata dict text,
263 255 result_content dict text,
264 256 result_buffers bufs blob,
265 257 queue text,
266 pyin text,
258 execute_input text,
267 259 pyout text,
268 260 pyerr text,
269 261 stdout text,
270 262 stderr text)
271 263 """%self.table)
272 264 self._db.commit()
273 265
274 266 def _dict_to_list(self, d):
275 267 """turn a mongodb-style record dict into a list."""
276 268
277 269 return [ d[key] for key in self._keys ]
278 270
279 271 def _list_to_dict(self, line, keys=None):
280 272 """Inverse of dict_to_list"""
281 273 keys = self._keys if keys is None else keys
282 274 d = self._defaults(keys)
283 275 for key,value in zip(keys, line):
284 276 d[key] = value
285 277
286 278 return d
287 279
288 280 def _render_expression(self, check):
289 281 """Turn a mongodb-style search dict into an SQL query."""
290 282 expressions = []
291 283 args = []
292 284
293 285 skeys = set(check.keys())
294 286 skeys.difference_update(set(self._keys))
295 287 skeys.difference_update(set(['buffers', 'result_buffers']))
296 288 if skeys:
297 289 raise KeyError("Illegal testing key(s): %s"%skeys)
298 290
299 291 for name,sub_check in iteritems(check):
300 292 if isinstance(sub_check, dict):
301 293 for test,value in iteritems(sub_check):
302 294 try:
303 295 op = operators[test]
304 296 except KeyError:
305 297 raise KeyError("Unsupported operator: %r"%test)
306 298 if isinstance(op, tuple):
307 299 op, join = op
308 300
309 301 if value is None and op in null_operators:
310 302 expr = "%s %s" % (name, null_operators[op])
311 303 else:
312 304 expr = "%s %s ?"%(name, op)
313 305 if isinstance(value, (tuple,list)):
314 306 if op in null_operators and any([v is None for v in value]):
315 307 # equality tests don't work with NULL
316 308 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
317 309 expr = '( %s )'%( join.join([expr]*len(value)) )
318 310 args.extend(value)
319 311 else:
320 312 args.append(value)
321 313 expressions.append(expr)
322 314 else:
323 315 # it's an equality check
324 316 if sub_check is None:
325 317 expressions.append("%s IS NULL" % name)
326 318 else:
327 319 expressions.append("%s = ?"%name)
328 320 args.append(sub_check)
329 321
330 322 expr = " AND ".join(expressions)
331 323 return expr, args
332 324
333 325 def add_record(self, msg_id, rec):
334 326 """Add a new Task Record, by msg_id."""
335 327 d = self._defaults()
336 328 d.update(rec)
337 329 d['msg_id'] = msg_id
338 330 line = self._dict_to_list(d)
339 331 tups = '(%s)'%(','.join(['?']*len(line)))
340 332 self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line)
341 333 # self._db.commit()
342 334
343 335 def get_record(self, msg_id):
344 336 """Get a specific Task Record, by msg_id."""
345 337 cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
346 338 line = cursor.fetchone()
347 339 if line is None:
348 340 raise KeyError("No such msg: %r"%msg_id)
349 341 return self._list_to_dict(line)
350 342
351 343 def update_record(self, msg_id, rec):
352 344 """Update the data in an existing record."""
353 345 query = "UPDATE '%s' SET "%self.table
354 346 sets = []
355 347 keys = sorted(rec.keys())
356 348 values = []
357 349 for key in keys:
358 350 sets.append('%s = ?'%key)
359 351 values.append(rec[key])
360 352 query += ', '.join(sets)
361 353 query += ' WHERE msg_id == ?'
362 354 values.append(msg_id)
363 355 self._db.execute(query, values)
364 356 # self._db.commit()
365 357
366 358 def drop_record(self, msg_id):
367 359 """Remove a record from the DB."""
368 360 self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
369 361 # self._db.commit()
370 362
371 363 def drop_matching_records(self, check):
372 364 """Remove a record from the DB."""
373 365 expr,args = self._render_expression(check)
374 366 query = "DELETE FROM '%s' WHERE %s"%(self.table, expr)
375 367 self._db.execute(query,args)
376 368 # self._db.commit()
377 369
378 370 def find_records(self, check, keys=None):
379 371 """Find records matching a query dict, optionally extracting subset of keys.
380 372
381 373 Returns list of matching records.
382 374
383 375 Parameters
384 376 ----------
385 377
386 378 check: dict
387 379 mongodb-style query argument
388 380 keys: list of strs [optional]
389 381 if specified, the subset of keys to extract. msg_id will *always* be
390 382 included.
391 383 """
392 384 if keys:
393 385 bad_keys = [ key for key in keys if key not in self._keys ]
394 386 if bad_keys:
395 387 raise KeyError("Bad record key(s): %s"%bad_keys)
396 388
397 389 if keys:
398 390 # ensure msg_id is present and first:
399 391 if 'msg_id' in keys:
400 392 keys.remove('msg_id')
401 393 keys.insert(0, 'msg_id')
402 394 req = ', '.join(keys)
403 395 else:
404 396 req = '*'
405 397 expr,args = self._render_expression(check)
406 398 query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr)
407 399 cursor = self._db.execute(query, args)
408 400 matches = cursor.fetchall()
409 401 records = []
410 402 for line in matches:
411 403 rec = self._list_to_dict(line, keys)
412 404 records.append(rec)
413 405 return records
414 406
415 407 def get_history(self):
416 408 """get all msg_ids, ordered by time submitted."""
417 409 query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table
418 410 cursor = self._db.execute(query)
419 411 # will be a list of length 1 tuples
420 412 return [ tup[0] for tup in cursor.fetchall()]
421 413
422 414 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,215 +1,215 b''
1 """ Defines a KernelManager that provides signals and slots.
2 """
1 """Defines a KernelManager that provides signals and slots."""
2
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
3 5
4 # System library imports.
5 6 from IPython.external.qt import QtCore
6 7
7 # IPython imports.
8 8 from IPython.utils.traitlets import HasTraits, Type
9 9 from .util import MetaQObjectHasTraits, SuperQObject
10 10
11 11
12 12 class ChannelQObject(SuperQObject):
13 13
14 14 # Emitted when the channel is started.
15 15 started = QtCore.Signal()
16 16
17 17 # Emitted when the channel is stopped.
18 18 stopped = QtCore.Signal()
19 19
20 20 #---------------------------------------------------------------------------
21 21 # Channel interface
22 22 #---------------------------------------------------------------------------
23 23
24 24 def start(self):
25 25 """ Reimplemented to emit signal.
26 26 """
27 27 super(ChannelQObject, self).start()
28 28 self.started.emit()
29 29
30 30 def stop(self):
31 31 """ Reimplemented to emit signal.
32 32 """
33 33 super(ChannelQObject, self).stop()
34 34 self.stopped.emit()
35 35
36 36 #---------------------------------------------------------------------------
37 37 # InProcessChannel interface
38 38 #---------------------------------------------------------------------------
39 39
40 40 def call_handlers_later(self, *args, **kwds):
41 41 """ Call the message handlers later.
42 42 """
43 43 do_later = lambda: self.call_handlers(*args, **kwds)
44 44 QtCore.QTimer.singleShot(0, do_later)
45 45
46 46 def process_events(self):
47 47 """ Process any pending GUI events.
48 48 """
49 49 QtCore.QCoreApplication.instance().processEvents()
50 50
51 51
52 52 class QtShellChannelMixin(ChannelQObject):
53 53
54 54 # Emitted when any message is received.
55 55 message_received = QtCore.Signal(object)
56 56
57 57 # Emitted when a reply has been received for the corresponding request type.
58 58 execute_reply = QtCore.Signal(object)
59 59 complete_reply = QtCore.Signal(object)
60 60 object_info_reply = QtCore.Signal(object)
61 61 history_reply = QtCore.Signal(object)
62 62
63 63 #---------------------------------------------------------------------------
64 64 # 'ShellChannel' interface
65 65 #---------------------------------------------------------------------------
66 66
67 67 def call_handlers(self, msg):
68 68 """ Reimplemented to emit signals instead of making callbacks.
69 69 """
70 70 # Emit the generic signal.
71 71 self.message_received.emit(msg)
72 72
73 73 # Emit signals for specialized message types.
74 74 msg_type = msg['header']['msg_type']
75 75 signal = getattr(self, msg_type, None)
76 76 if signal:
77 77 signal.emit(msg)
78 78
79 79
80 80 class QtIOPubChannelMixin(ChannelQObject):
81 81
82 82 # Emitted when any message is received.
83 83 message_received = QtCore.Signal(object)
84 84
85 85 # Emitted when a message of type 'stream' is received.
86 86 stream_received = QtCore.Signal(object)
87 87
88 # Emitted when a message of type 'pyin' is received.
89 pyin_received = QtCore.Signal(object)
88 # Emitted when a message of type 'execute_input' is received.
89 execute_input_received = QtCore.Signal(object)
90 90
91 91 # Emitted when a message of type 'pyout' is received.
92 92 pyout_received = QtCore.Signal(object)
93 93
94 94 # Emitted when a message of type 'pyerr' is received.
95 95 pyerr_received = QtCore.Signal(object)
96 96
97 97 # Emitted when a message of type 'display_data' is received
98 98 display_data_received = QtCore.Signal(object)
99 99
100 100 # Emitted when a crash report message is received from the kernel's
101 101 # last-resort sys.excepthook.
102 102 crash_received = QtCore.Signal(object)
103 103
104 104 # Emitted when a shutdown is noticed.
105 105 shutdown_reply_received = QtCore.Signal(object)
106 106
107 107 #---------------------------------------------------------------------------
108 108 # 'IOPubChannel' interface
109 109 #---------------------------------------------------------------------------
110 110
111 111 def call_handlers(self, msg):
112 112 """ Reimplemented to emit signals instead of making callbacks.
113 113 """
114 114 # Emit the generic signal.
115 115 self.message_received.emit(msg)
116 116 # Emit signals for specialized message types.
117 117 msg_type = msg['header']['msg_type']
118 118 signal = getattr(self, msg_type + '_received', None)
119 119 if signal:
120 120 signal.emit(msg)
121 121 elif msg_type in ('stdout', 'stderr'):
122 122 self.stream_received.emit(msg)
123 123
124 124 def flush(self):
125 125 """ Reimplemented to ensure that signals are dispatched immediately.
126 126 """
127 127 super(QtIOPubChannelMixin, self).flush()
128 128 QtCore.QCoreApplication.instance().processEvents()
129 129
130 130
131 131 class QtStdInChannelMixin(ChannelQObject):
132 132
133 133 # Emitted when any message is received.
134 134 message_received = QtCore.Signal(object)
135 135
136 136 # Emitted when an input request is received.
137 137 input_requested = QtCore.Signal(object)
138 138
139 139 #---------------------------------------------------------------------------
140 140 # 'StdInChannel' interface
141 141 #---------------------------------------------------------------------------
142 142
143 143 def call_handlers(self, msg):
144 144 """ Reimplemented to emit signals instead of making callbacks.
145 145 """
146 146 # Emit the generic signal.
147 147 self.message_received.emit(msg)
148 148
149 149 # Emit signals for specialized message types.
150 150 msg_type = msg['header']['msg_type']
151 151 if msg_type == 'input_request':
152 152 self.input_requested.emit(msg)
153 153
154 154
155 155 class QtHBChannelMixin(ChannelQObject):
156 156
157 157 # Emitted when the kernel has died.
158 158 kernel_died = QtCore.Signal(object)
159 159
160 160 #---------------------------------------------------------------------------
161 161 # 'HBChannel' interface
162 162 #---------------------------------------------------------------------------
163 163
164 164 def call_handlers(self, since_last_heartbeat):
165 165 """ Reimplemented to emit signals instead of making callbacks.
166 166 """
167 167 # Emit the generic signal.
168 168 self.kernel_died.emit(since_last_heartbeat)
169 169
170 170
171 171 class QtKernelRestarterMixin(MetaQObjectHasTraits('NewBase', (HasTraits, SuperQObject), {})):
172 172
173 173 _timer = None
174 174
175 175
176 176 class QtKernelManagerMixin(MetaQObjectHasTraits('NewBase', (HasTraits, SuperQObject), {})):
177 177 """ A KernelClient that provides signals and slots.
178 178 """
179 179
180 180 kernel_restarted = QtCore.Signal()
181 181
182 182
183 183 class QtKernelClientMixin(MetaQObjectHasTraits('NewBase', (HasTraits, SuperQObject), {})):
184 184 """ A KernelClient that provides signals and slots.
185 185 """
186 186
187 187 # Emitted when the kernel client has started listening.
188 188 started_channels = QtCore.Signal()
189 189
190 190 # Emitted when the kernel client has stopped listening.
191 191 stopped_channels = QtCore.Signal()
192 192
193 193 # Use Qt-specific channel classes that emit signals.
194 194 iopub_channel_class = Type(QtIOPubChannelMixin)
195 195 shell_channel_class = Type(QtShellChannelMixin)
196 196 stdin_channel_class = Type(QtStdInChannelMixin)
197 197 hb_channel_class = Type(QtHBChannelMixin)
198 198
199 199 #---------------------------------------------------------------------------
200 200 # 'KernelClient' interface
201 201 #---------------------------------------------------------------------------
202 202
203 203 #------ Channel management -------------------------------------------------
204 204
205 205 def start_channels(self, *args, **kw):
206 206 """ Reimplemented to emit signal.
207 207 """
208 208 super(QtKernelClientMixin, self).start_channels(*args, **kw)
209 209 self.started_channels.emit()
210 210
211 211 def stop_channels(self):
212 212 """ Reimplemented to emit signal.
213 213 """
214 214 super(QtKernelClientMixin, self).stop_channels()
215 215 self.stopped_channels.emit()
General Comments 0
You need to be logged in to leave comments. Login now