##// END OF EJS Templates
Fixing another bug in msg_type refactoring.
Brian E. Granger -
Show More
@@ -1,440 +1,440
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 * Brian Granger
9 9 * Fernando Perez
10 10 * Evan Patterson
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2010-2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 # Standard library imports.
24 24 from __future__ import print_function
25 25
26 26 import sys
27 27 import time
28 28
29 29 from code import CommandCompiler
30 30 from datetime import datetime
31 31 from pprint import pprint
32 32
33 33 # System library imports.
34 34 import zmq
35 35 from zmq.eventloop import ioloop, zmqstream
36 36
37 37 # Local imports.
38 38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 39 from IPython.zmq.completer import KernelCompleter
40 40
41 41 from IPython.parallel.error import wrap_exception
42 42 from IPython.parallel.factory import SessionFactory
43 43 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44 44
45 45 def printer(*args):
46 46 pprint(args, stream=sys.__stdout__)
47 47
48 48
49 49 class _Passer(zmqstream.ZMQStream):
50 50 """Empty class that implements `send()` that does nothing.
51 51
52 52 Subclass ZMQStream for Session typechecking
53 53
54 54 """
55 55 def __init__(self, *args, **kwargs):
56 56 pass
57 57
58 58 def send(self, *args, **kwargs):
59 59 pass
60 60 send_multipart = send
61 61
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Main kernel class
65 65 #-----------------------------------------------------------------------------
66 66
67 67 class Kernel(SessionFactory):
68 68
69 69 #---------------------------------------------------------------------------
70 70 # Kernel interface
71 71 #---------------------------------------------------------------------------
72 72
73 73 # kwargs:
74 74 exec_lines = List(Unicode, config=True,
75 75 help="List of lines to execute")
76 76
77 77 # identities:
78 78 int_id = Int(-1)
79 79 bident = CBytes()
80 80 ident = Unicode()
81 81 def _ident_changed(self, name, old, new):
82 82 self.bident = asbytes(new)
83 83
84 84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85 85
86 86 control_stream = Instance(zmqstream.ZMQStream)
87 87 task_stream = Instance(zmqstream.ZMQStream)
88 88 iopub_stream = Instance(zmqstream.ZMQStream)
89 89 client = Instance('IPython.parallel.Client')
90 90
91 91 # internals
92 92 shell_streams = List()
93 93 compiler = Instance(CommandCompiler, (), {})
94 94 completer = Instance(KernelCompleter)
95 95
96 96 aborted = Set()
97 97 shell_handlers = Dict()
98 98 control_handlers = Dict()
99 99
100 100 def _set_prefix(self):
101 101 self.prefix = "engine.%s"%self.int_id
102 102
103 103 def _connect_completer(self):
104 104 self.completer = KernelCompleter(self.user_ns)
105 105
106 106 def __init__(self, **kwargs):
107 107 super(Kernel, self).__init__(**kwargs)
108 108 self._set_prefix()
109 109 self._connect_completer()
110 110
111 111 self.on_trait_change(self._set_prefix, 'id')
112 112 self.on_trait_change(self._connect_completer, 'user_ns')
113 113
114 114 # Build dict of handlers for message types
115 115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 116 'clear_request']:
117 117 self.shell_handlers[msg_type] = getattr(self, msg_type)
118 118
119 119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 120 self.control_handlers[msg_type] = getattr(self, msg_type)
121 121
122 122 self._initial_exec_lines()
123 123
124 124 def _wrap_exception(self, method=None):
125 125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 126 content=wrap_exception(e_info)
127 127 return content
128 128
129 129 def _initial_exec_lines(self):
130 130 s = _Passer()
131 131 content = dict(silent=True, user_variable=[],user_expressions=[])
132 132 for line in self.exec_lines:
133 133 self.log.debug("executing initialization: %s"%line)
134 134 content.update({'code':line})
135 135 msg = self.session.msg('execute_request', content)
136 136 self.execute_request(s, [], msg)
137 137
138 138
139 139 #-------------------- control handlers -----------------------------
140 140 def abort_queues(self):
141 141 for stream in self.shell_streams:
142 142 if stream:
143 143 self.abort_queue(stream)
144 144
145 145 def abort_queue(self, stream):
146 146 while True:
147 147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 148 if msg is None:
149 149 return
150 150
151 151 self.log.info("Aborting:")
152 152 self.log.info(str(msg))
153 153 msg_type = msg['header']['msg_type']
154 154 reply_type = msg_type.split('_')[0] + '_reply'
155 155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 156 # self.reply_socket.send(ident,zmq.SNDMORE)
157 157 # self.reply_socket.send_json(reply_msg)
158 158 reply_msg = self.session.send(stream, reply_type,
159 159 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 160 self.log.debug(str(reply_msg))
161 161 # We need to wait a bit for requests to come in. This can probably
162 162 # be set shorter for true asynchronous clients.
163 163 time.sleep(0.05)
164 164
165 165 def abort_request(self, stream, ident, parent):
166 166 """abort a specifig msg by id"""
167 167 msg_ids = parent['content'].get('msg_ids', None)
168 168 if isinstance(msg_ids, basestring):
169 169 msg_ids = [msg_ids]
170 170 if not msg_ids:
171 171 self.abort_queues()
172 172 for mid in msg_ids:
173 173 self.aborted.add(str(mid))
174 174
175 175 content = dict(status='ok')
176 176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 177 parent=parent, ident=ident)
178 178 self.log.debug(str(reply_msg))
179 179
180 180 def shutdown_request(self, stream, ident, parent):
181 181 """kill ourself. This should really be handled in an external process"""
182 182 try:
183 183 self.abort_queues()
184 184 except:
185 185 content = self._wrap_exception('shutdown')
186 186 else:
187 187 content = dict(parent['content'])
188 188 content['status'] = 'ok'
189 189 msg = self.session.send(stream, 'shutdown_reply',
190 190 content=content, parent=parent, ident=ident)
191 191 self.log.debug(str(msg))
192 192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 193 dc.start()
194 194
195 195 def dispatch_control(self, msg):
196 196 idents,msg = self.session.feed_identities(msg, copy=False)
197 197 try:
198 198 msg = self.session.unserialize(msg, content=True, copy=False)
199 199 except:
200 200 self.log.error("Invalid Message", exc_info=True)
201 201 return
202 202 else:
203 203 self.log.debug("Control received, %s", msg)
204 204
205 205 header = msg['header']
206 206 msg_id = header['msg_id']
207 207 msg_type = header['msg_type']
208 208
209 209 handler = self.control_handlers.get(msg_type, None)
210 210 if handler is None:
211 211 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg_type)
212 212 else:
213 213 handler(self.control_stream, idents, msg)
214 214
215 215
216 216 #-------------------- queue helpers ------------------------------
217 217
218 218 def check_dependencies(self, dependencies):
219 219 if not dependencies:
220 220 return True
221 221 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
222 222 anyorall = dependencies[0]
223 223 dependencies = dependencies[1]
224 224 else:
225 225 anyorall = 'all'
226 226 results = self.client.get_results(dependencies,status_only=True)
227 227 if results['status'] != 'ok':
228 228 return False
229 229
230 230 if anyorall == 'any':
231 231 if not results['completed']:
232 232 return False
233 233 else:
234 234 if results['pending']:
235 235 return False
236 236
237 237 return True
238 238
239 239 def check_aborted(self, msg_id):
240 240 return msg_id in self.aborted
241 241
242 242 #-------------------- queue handlers -----------------------------
243 243
244 244 def clear_request(self, stream, idents, parent):
245 245 """Clear our namespace."""
246 246 self.user_ns = {}
247 247 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
248 248 content = dict(status='ok'))
249 249 self._initial_exec_lines()
250 250
251 251 def execute_request(self, stream, ident, parent):
252 252 self.log.debug('execute request %s'%parent)
253 253 try:
254 254 code = parent[u'content'][u'code']
255 255 except:
256 256 self.log.error("Got bad msg: %s"%parent, exc_info=True)
257 257 return
258 258 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
259 259 ident=asbytes('%s.pyin'%self.prefix))
260 260 started = datetime.now()
261 261 try:
262 262 comp_code = self.compiler(code, '<zmq-kernel>')
263 263 # allow for not overriding displayhook
264 264 if hasattr(sys.displayhook, 'set_parent'):
265 265 sys.displayhook.set_parent(parent)
266 266 sys.stdout.set_parent(parent)
267 267 sys.stderr.set_parent(parent)
268 268 exec comp_code in self.user_ns, self.user_ns
269 269 except:
270 270 exc_content = self._wrap_exception('execute')
271 271 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
272 272 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
273 273 ident=asbytes('%s.pyerr'%self.prefix))
274 274 reply_content = exc_content
275 275 else:
276 276 reply_content = {'status' : 'ok'}
277 277
278 278 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
279 279 ident=ident, subheader = dict(started=started))
280 280 self.log.debug(str(reply_msg))
281 281 if reply_msg['content']['status'] == u'error':
282 282 self.abort_queues()
283 283
284 284 def complete_request(self, stream, ident, parent):
285 285 matches = {'matches' : self.complete(parent),
286 286 'status' : 'ok'}
287 287 completion_msg = self.session.send(stream, 'complete_reply',
288 288 matches, parent, ident)
289 289 # print >> sys.__stdout__, completion_msg
290 290
291 291 def complete(self, msg):
292 292 return self.completer.complete(msg.content.line, msg.content.text)
293 293
294 294 def apply_request(self, stream, ident, parent):
295 295 # flush previous reply, so this request won't block it
296 296 stream.flush(zmq.POLLOUT)
297 297 try:
298 298 content = parent[u'content']
299 299 bufs = parent[u'buffers']
300 300 msg_id = parent['header']['msg_id']
301 301 # bound = parent['header'].get('bound', False)
302 302 except:
303 303 self.log.error("Got bad msg: %s"%parent, exc_info=True)
304 304 return
305 305 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
306 306 # self.iopub_stream.send(pyin_msg)
307 307 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
308 308 sub = {'dependencies_met' : True, 'engine' : self.ident,
309 309 'started': datetime.now()}
310 310 try:
311 311 # allow for not overriding displayhook
312 312 if hasattr(sys.displayhook, 'set_parent'):
313 313 sys.displayhook.set_parent(parent)
314 314 sys.stdout.set_parent(parent)
315 315 sys.stderr.set_parent(parent)
316 316 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
317 317 working = self.user_ns
318 318 # suffix =
319 319 prefix = "_"+str(msg_id).replace("-","")+"_"
320 320
321 321 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
322 322 # if bound:
323 323 # bound_ns = Namespace(working)
324 324 # args = [bound_ns]+list(args)
325 325
326 326 fname = getattr(f, '__name__', 'f')
327 327
328 328 fname = prefix+"f"
329 329 argname = prefix+"args"
330 330 kwargname = prefix+"kwargs"
331 331 resultname = prefix+"result"
332 332
333 333 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
334 334 # print ns
335 335 working.update(ns)
336 336 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
337 337 try:
338 338 exec code in working,working
339 339 result = working.get(resultname)
340 340 finally:
341 341 for key in ns.iterkeys():
342 342 working.pop(key)
343 343 # if bound:
344 344 # working.update(bound_ns)
345 345
346 346 packed_result,buf = serialize_object(result)
347 347 result_buf = [packed_result]+buf
348 348 except:
349 349 exc_content = self._wrap_exception('apply')
350 350 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
351 351 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
352 352 ident=asbytes('%s.pyerr'%self.prefix))
353 353 reply_content = exc_content
354 354 result_buf = []
355 355
356 356 if exc_content['ename'] == 'UnmetDependency':
357 357 sub['dependencies_met'] = False
358 358 else:
359 359 reply_content = {'status' : 'ok'}
360 360
361 361 # put 'ok'/'error' status in header, for scheduler introspection:
362 362 sub['status'] = reply_content['status']
363 363
364 364 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
365 365 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
366 366
367 367 # flush i/o
368 368 # should this be before reply_msg is sent, like in the single-kernel code,
369 369 # or should nothing get in the way of real results?
370 370 sys.stdout.flush()
371 371 sys.stderr.flush()
372 372
373 373 def dispatch_queue(self, stream, msg):
374 374 self.control_stream.flush()
375 375 idents,msg = self.session.feed_identities(msg, copy=False)
376 376 try:
377 377 msg = self.session.unserialize(msg, content=True, copy=False)
378 378 except:
379 379 self.log.error("Invalid Message", exc_info=True)
380 380 return
381 381 else:
382 382 self.log.debug("Message received, %s", msg)
383 383
384 384
385 385 header = msg['header']
386 386 msg_id = header['msg_id']
387 msg['header']['msg_type']
387 msg_type = msg['header']['msg_type']
388 388 if self.check_aborted(msg_id):
389 389 self.aborted.remove(msg_id)
390 390 # is it safe to assume a msg_id will not be resubmitted?
391 391 reply_type = msg_type.split('_')[0] + '_reply'
392 392 status = {'status' : 'aborted'}
393 393 reply_msg = self.session.send(stream, reply_type, subheader=status,
394 394 content=status, parent=msg, ident=idents)
395 395 return
396 396 handler = self.shell_handlers.get(msg_type, None)
397 397 if handler is None:
398 398 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg_type)
399 399 else:
400 400 handler(stream, idents, msg)
401 401
402 402 def start(self):
403 403 #### stream mode:
404 404 if self.control_stream:
405 405 self.control_stream.on_recv(self.dispatch_control, copy=False)
406 406 self.control_stream.on_err(printer)
407 407
408 408 def make_dispatcher(stream):
409 409 def dispatcher(msg):
410 410 return self.dispatch_queue(stream, msg)
411 411 return dispatcher
412 412
413 413 for s in self.shell_streams:
414 414 s.on_recv(make_dispatcher(s), copy=False)
415 415 s.on_err(printer)
416 416
417 417 if self.iopub_stream:
418 418 self.iopub_stream.on_err(printer)
419 419
420 420 #### while True mode:
421 421 # while True:
422 422 # idle = True
423 423 # try:
424 424 # msg = self.shell_stream.socket.recv_multipart(
425 425 # zmq.NOBLOCK, copy=False)
426 426 # except zmq.ZMQError, e:
427 427 # if e.errno != zmq.EAGAIN:
428 428 # raise e
429 429 # else:
430 430 # idle=False
431 431 # self.dispatch_queue(self.shell_stream, msg)
432 432 #
433 433 # if not self.task_stream.empty():
434 434 # idle=False
435 435 # msg = self.task_stream.recv_multipart()
436 436 # self.dispatch_queue(self.task_stream, msg)
437 437 # if idle:
438 438 # # don't busywait
439 439 # time.sleep(1e-3)
440 440
General Comments 0
You need to be logged in to leave comments. Login now