added dependency decorator
MinRK
@@ -0,0 +1,66 @@
1 """Dependency utilities"""
2
3 from IPython.external.decorator import decorator
4
5 # flags
6 ALL = 1 << 0
7 ANY = 1 << 1
8 HERE = 1 << 2
9 ANYWHERE = 1 << 3
10
11 class UnmetDependency(Exception):
12 pass
13
14 class depend2(object):
15 """dependency decorator"""
16 def __init__(self, f, *args, **kwargs):
17 print "Inside __init__()"
18 self.dependency = (f,args,kwargs)
19
20 def __call__(self, f, *args, **kwargs):
21 f._dependency = self.dependency
22 return decorator(_depend_wrapper, f)
23
24 class depend(object):
25 """Dependency decorator, for use with tasks."""
26 def __init__(self, f, *args, **kwargs):
27 print "Inside __init__()"
28 self.f = f
29 self.args = args
30 self.kwargs = kwargs
31
32 def __call__(self, f):
33 return dependent(f, self.f, *self.args, **self.kwargs)
34
35 class dependent(object):
36 """A function that depends on another function.
37 This is implemented as an object rather than as the closure a
38 traditional decorator would create, because closures are not picklable.
39 """
40
41 def __init__(self, f, df, *dargs, **dkwargs):
42 self.f = f
43 self.func_name = self.f.func_name
44 self.df = df
45 self.dargs = dargs
46 self.dkwargs = dkwargs
47
48 def __call__(self, *args, **kwargs):
49 if self.df(*self.dargs, **self.dkwargs) is False:
50 raise UnmetDependency()
51 return self.f(*args, **kwargs)
52
53 def evaluate_dependency(deps):
54 """Evaluate wheter dependencies are met.
55
56 Parameters
57 ----------
58 deps : dict
59 """
60 pass
61
62 def _check_dependency(flag):
63 pass
64
65
66 __all__ = ['UnmetDependency', 'depend', 'evaluate_dependencies']
\ No newline at end of file
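For context, a minimal usage sketch of the new decorator (not part of the diff). It assumes the module is importable as "dependency", which is how the kernel below imports it; has_numpy and task are made-up names for illustration.

    from dependency import depend, UnmetDependency

    def has_numpy():
        # dependency check: return False when the requirement is unmet
        try:
            import numpy
        except ImportError:
            return False
        return True

    @depend(has_numpy)
    def task(n):
        import numpy
        return numpy.arange(n).sum()

    # task is now a picklable dependent instance; calling it runs has_numpy()
    # first and raises UnmetDependency if that check returns False.
    try:
        print task(10)
    except UnmetDependency:
        print "dependency not met on this engine"

Because dependent keeps plain attributes instead of a closure, the wrapped task can still be pickled and shipped to an engine (see the pickleutil changes below).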
@@ -1,510 +1,507 @@
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 import __builtin__
7 7 import os
8 8 import sys
9 9 import time
10 10 import traceback
11 11 from signal import SIGTERM, SIGKILL
12 12 from pprint import pprint
13 13
14 14 from code import CommandCompiler
15 15
16 16 import zmq
17 17 from zmq.eventloop import ioloop, zmqstream
18 18
19 from IPython.zmq.completer import KernelCompleter
20
19 21 from streamsession import StreamSession, Message, extract_header, serialize_object,\
20 22 unpack_apply_message
21 from IPython.zmq.completer import KernelCompleter
23 from dependency import UnmetDependency
22 24
23 25 def printer(*args):
24 26 pprint(args)
25 27
26 28 class OutStream(object):
27 29 """A file like object that publishes the stream to a 0MQ PUB socket."""
28 30
29 31 def __init__(self, session, pub_socket, name, max_buffer=200):
30 32 self.session = session
31 33 self.pub_socket = pub_socket
32 34 self.name = name
33 35 self._buffer = []
34 36 self._buffer_len = 0
35 37 self.max_buffer = max_buffer
36 38 self.parent_header = {}
37 39
38 40 def set_parent(self, parent):
39 41 self.parent_header = extract_header(parent)
40 42
41 43 def close(self):
42 44 self.pub_socket = None
43 45
44 46 def flush(self):
45 47 if self.pub_socket is None:
46 48 raise ValueError(u'I/O operation on closed file')
47 49 else:
48 50 if self._buffer:
49 51 data = ''.join(self._buffer)
50 52 content = {u'name':self.name, u'data':data}
51 53 # msg = self.session.msg(u'stream', content=content,
52 54 # parent=self.parent_header)
53 55 msg = self.session.send(self.pub_socket, u'stream', content=content, parent=self.parent_header)
54 56 # print>>sys.__stdout__, Message(msg)
55 57 # self.pub_socket.send_json(msg)
56 58 self._buffer_len = 0
57 59 self._buffer = []
58 60
59 61 def isattr(self):
60 62 return False
61 63
62 64 def next(self):
63 65 raise IOError('Read not supported on a write only stream.')
64 66
65 67 def read(self, size=None):
66 68 raise IOError('Read not supported on a write only stream.')
67 69
68 70 readline=read
69 71
70 72 def write(self, s):
71 73 if self.pub_socket is None:
72 74 raise ValueError('I/O operation on closed file')
73 75 else:
74 76 self._buffer.append(s)
75 77 self._buffer_len += len(s)
76 78 self._maybe_send()
77 79
78 80 def _maybe_send(self):
79 81 if '\n' in self._buffer[-1]:
80 82 self.flush()
81 83 if self._buffer_len > self.max_buffer:
82 84 self.flush()
83 85
84 86 def writelines(self, sequence):
85 87 if self.pub_socket is None:
86 88 raise ValueError('I/O operation on closed file')
87 89 else:
88 90 for s in sequence:
89 91 self.write(s)
90 92
91 93
92 94 class DisplayHook(object):
93 95
94 96 def __init__(self, session, pub_socket):
95 97 self.session = session
96 98 self.pub_socket = pub_socket
97 99 self.parent_header = {}
98 100
99 101 def __call__(self, obj):
100 102 if obj is None:
101 103 return
102 104
103 105 __builtin__._ = obj
104 106 # msg = self.session.msg(u'pyout', {u'data':repr(obj)},
105 107 # parent=self.parent_header)
106 108 # self.pub_socket.send_json(msg)
107 109 self.session.send(self.pub_socket, u'pyout', content={u'data':repr(obj)}, parent=self.parent_header)
108 110
109 111 def set_parent(self, parent):
110 112 self.parent_header = extract_header(parent)
111 113
112 114
113 115 class RawInput(object):
114 116
115 117 def __init__(self, session, socket):
116 118 self.session = session
117 119 self.socket = socket
118 120
119 121 def __call__(self, prompt=None):
120 122 msg = self.session.msg(u'raw_input')
121 123 self.socket.send_json(msg)
122 124 while True:
123 125 try:
124 126 reply = self.socket.recv_json(zmq.NOBLOCK)
125 127 except zmq.ZMQError, e:
126 128 if e.errno == zmq.EAGAIN:
127 129 pass
128 130 else:
129 131 raise
130 132 else:
131 133 break
132 134 return reply[u'content'][u'data']
133 135
134 136
135 137 class Kernel(object):
136 138
137 139 def __init__(self, session, control_stream, reply_stream, pub_stream,
138 140 task_stream=None, client=None):
139 141 self.session = session
140 142 self.control_stream = control_stream
141 143 # self.control_socket = control_stream.socket
142 144 self.reply_stream = reply_stream
143 145 self.task_stream = task_stream
144 146 self.pub_stream = pub_stream
145 147 self.client = client
146 148 self.user_ns = {}
147 149 self.history = []
148 150 self.compiler = CommandCompiler()
149 151 self.completer = KernelCompleter(self.user_ns)
150 152 self.aborted = set()
151 153
152 154 # Build dict of handlers for message types
153 155 self.queue_handlers = {}
154 156 self.control_handlers = {}
155 157 for msg_type in ['execute_request', 'complete_request', 'apply_request']:
156 158 self.queue_handlers[msg_type] = getattr(self, msg_type)
157 159
158 for msg_type in ['kill_request', 'abort_request']:
160 for msg_type in ['kill_request', 'abort_request']+self.queue_handlers.keys():
159 161 self.control_handlers[msg_type] = getattr(self, msg_type)
160 162
161 163 #-------------------- control handlers -----------------------------
162 164 def abort_queues(self):
163 165 for stream in (self.task_stream, self.reply_stream):
164 166 if stream:
165 167 self.abort_queue(stream)
166 168
167 169 def abort_queue(self, stream):
168 170 while True:
169 171 try:
170 172 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
171 173 except zmq.ZMQError, e:
172 174 if e.errno == zmq.EAGAIN:
173 175 break
174 176 else:
175 177 return
176 178 else:
177 179 if msg is None:
178 180 return
179 181 else:
180 182 idents,msg = msg
181 183
182 184 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
183 185 # msg = self.reply_socket.recv_json()
184 186 print>>sys.__stdout__, "Aborting:"
185 187 print>>sys.__stdout__, Message(msg)
186 188 msg_type = msg['msg_type']
187 189 reply_type = msg_type.split('_')[0] + '_reply'
188 190 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
189 191 # self.reply_socket.send(ident,zmq.SNDMORE)
190 192 # self.reply_socket.send_json(reply_msg)
191 193 reply_msg = self.session.send(stream, reply_type,
192 194 content={'status' : 'aborted'}, parent=msg, ident=idents)
193 195 print>>sys.__stdout__, Message(reply_msg)
194 196 # We need to wait a bit for requests to come in. This can probably
195 197 # be set shorter for true asynchronous clients.
196 198 time.sleep(0.05)
197 199
198 200 def abort_request(self, stream, ident, parent):
199 201 """abort a specifig msg by id"""
200 202 msg_ids = parent['content'].get('msg_ids', None)
201 203 if isinstance(msg_ids, basestring):
202 204 msg_ids = [msg_ids]
203 205 if not msg_ids:
204 206 self.abort_queues()
205 207 for mid in msg_ids:
206 208 self.aborted.add(str(mid))
207 209
208 210 content = dict(status='ok')
209 211 reply_msg = self.session.send(stream, 'abort_reply', content=content, parent=parent,
210 212 ident=ident)
211 213 print>>sys.__stdout__, Message(reply_msg)
212 214
213 215 def kill_request(self, stream, idents, parent):
214 216 """kill ourselves. This should really be handled in an external process"""
215 217 self.abort_queues()
216 218 msg = self.session.send(stream, 'kill_reply', ident=idents, parent=parent,
217 219 content = dict(status='ok'))
218 220 # we can know that a message is done if we *don't* use streams, but
219 221 # use a socket directly with MessageTracker
220 222 time.sleep(.5)
221 223 os.kill(os.getpid(), SIGTERM)
222 224 time.sleep(1)
223 225 os.kill(os.getpid(), SIGKILL)
224 226
225 227 def dispatch_control(self, msg):
226 228 idents,msg = self.session.feed_identities(msg, copy=False)
227 229 msg = self.session.unpack_message(msg, content=True, copy=False)
228 230
229 231 header = msg['header']
230 232 msg_id = header['msg_id']
231 233
232 234 handler = self.control_handlers.get(msg['msg_type'], None)
233 235 if handler is None:
234 236 print >> sys.__stderr__, "UNKNOWN CONTROL MESSAGE TYPE:", msg
235 237 else:
236 238 handler(self.control_stream, idents, msg)
237 239
238 240 # def flush_control(self):
239 241 # while any(zmq.select([self.control_socket],[],[],1e-4)):
240 242 # try:
241 243 # msg = self.control_socket.recv_multipart(zmq.NOBLOCK, copy=False)
242 244 # except zmq.ZMQError, e:
243 245 # if e.errno != zmq.EAGAIN:
244 246 # raise e
245 247 # return
246 248 # else:
247 249 # self.dispatch_control(msg)
248 250
249 251
250 252 #-------------------- queue helpers ------------------------------
251 253
252 254 def check_dependencies(self, dependencies):
253 255 if not dependencies:
254 256 return True
255 257 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
256 258 anyorall = dependencies[0]
257 259 dependencies = dependencies[1]
258 260 else:
259 261 anyorall = 'all'
260 262 results = self.client.get_results(dependencies,status_only=True)
261 263 if results['status'] != 'ok':
262 264 return False
263 265
264 266 if anyorall == 'any':
265 267 if not results['completed']:
266 268 return False
267 269 else:
268 270 if results['pending']:
269 271 return False
270 272
271 273 return True
272 274
273 275 def check_aborted(self, msg_id):
274 276 return msg_id in self.aborted
275 277
276 def unmet_dependencies(self, stream, idents, msg):
277 reply_type = msg['msg_type'].split('_')[0] + '_reply'
278 content = dict(status='resubmitted', reason='unmet dependencies')
279 reply_msg = self.session.send(stream, reply_type,
280 content=content, parent=msg, ident=idents)
281 ### TODO: actually resubmit it ###
282
283 278 #-------------------- queue handlers -----------------------------
284 279
285 280 def execute_request(self, stream, ident, parent):
286 281 try:
287 282 code = parent[u'content'][u'code']
288 283 except:
289 284 print>>sys.__stderr__, "Got bad msg: "
290 285 print>>sys.__stderr__, Message(parent)
291 286 return
292 287 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
293 288 # self.pub_stream.send(pyin_msg)
294 289 self.session.send(self.pub_stream, u'pyin', {u'code':code},parent=parent)
295 290 try:
296 291 comp_code = self.compiler(code, '<zmq-kernel>')
297 292 # allow for not overriding displayhook
298 293 if hasattr(sys.displayhook, 'set_parent'):
299 294 sys.displayhook.set_parent(parent)
300 295 exec comp_code in self.user_ns, self.user_ns
301 296 except:
302 297 # result = u'error'
303 298 etype, evalue, tb = sys.exc_info()
304 299 tb = traceback.format_exception(etype, evalue, tb)
305 300 exc_content = {
306 301 u'status' : u'error',
307 302 u'traceback' : tb,
308 303 u'etype' : unicode(etype),
309 304 u'evalue' : unicode(evalue)
310 305 }
311 306 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
312 307 self.session.send(self.pub_stream, u'pyerr', exc_content, parent=parent)
313 308 reply_content = exc_content
314 309 else:
315 310 reply_content = {'status' : 'ok'}
316 311 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
317 312 # self.reply_socket.send(ident, zmq.SNDMORE)
318 313 # self.reply_socket.send_json(reply_msg)
319 314 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, ident=ident)
320 315 # print>>sys.__stdout__, Message(reply_msg)
321 316 if reply_msg['content']['status'] == u'error':
322 317 self.abort_queues()
323 318
324 319 def complete_request(self, stream, ident, parent):
325 320 matches = {'matches' : self.complete(parent),
326 321 'status' : 'ok'}
327 322 completion_msg = self.session.send(stream, 'complete_reply',
328 323 matches, parent, ident)
329 324 # print >> sys.__stdout__, completion_msg
330 325
331 326 def complete(self, msg):
332 327 return self.completer.complete(msg.content.line, msg.content.text)
333 328
334 329 def apply_request(self, stream, ident, parent):
335 330 try:
336 331 content = parent[u'content']
337 332 bufs = parent[u'buffers']
338 333 msg_id = parent['header']['msg_id']
339 334 bound = content.get('bound', False)
340 335 except:
341 336 print>>sys.__stderr__, "Got bad msg: "
342 337 print>>sys.__stderr__, Message(parent)
343 338 return
344 339 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
345 340 # self.pub_stream.send(pyin_msg)
346 341 # self.session.send(self.pub_stream, u'pyin', {u'code':code},parent=parent)
342 sub = {'dependencies_met' : True}
347 343 try:
348 344 # allow for not overriding displayhook
349 345 if hasattr(sys.displayhook, 'set_parent'):
350 346 sys.displayhook.set_parent(parent)
351 347 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
352 348 if bound:
353 349 working = self.user_ns
354 350 suffix = str(msg_id).replace("-","")
355 351 prefix = "_"
356 352
357 353 else:
358 354 working = dict()
359 355 suffix = prefix = "_" # prevent keyword collisions with lambda
360 356 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
361 357 # if f.fun
362 358 fname = prefix+f.func_name.strip('<>')+suffix
363 359 argname = prefix+"args"+suffix
364 360 kwargname = prefix+"kwargs"+suffix
365 361 resultname = prefix+"result"+suffix
366 362
367 363 ns = { fname : f, argname : args, kwargname : kwargs }
368 364 # print ns
369 365 working.update(ns)
370 366 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
371 367 exec code in working, working
372 368 result = working.get(resultname)
373 369 # clear the namespace
374 370 if bound:
375 371 for key in ns.iterkeys():
376 372 self.user_ns.pop(key)
377 373 else:
378 374 del working
379 375
380 376 packed_result,buf = serialize_object(result)
381 377 result_buf = [packed_result]+buf
382 378 except:
383 379 result = u'error'
384 380 etype, evalue, tb = sys.exc_info()
385 381 tb = traceback.format_exception(etype, evalue, tb)
386 382 exc_content = {
387 383 u'status' : u'error',
388 384 u'traceback' : tb,
389 385 u'etype' : unicode(etype),
390 386 u'evalue' : unicode(evalue)
391 387 }
392 388 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
393 389 self.session.send(self.pub_stream, u'pyerr', exc_content, parent=parent)
394 390 reply_content = exc_content
395 391 result_buf = []
392
393 if etype is UnmetDependency:
394 sub = {'dependencies_met' : False}
396 395 else:
397 396 reply_content = {'status' : 'ok'}
398 397 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
399 398 # self.reply_socket.send(ident, zmq.SNDMORE)
400 399 # self.reply_socket.send_json(reply_msg)
401 reply_msg = self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident,buffers=result_buf)
400 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
401 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
402 402 # print>>sys.__stdout__, Message(reply_msg)
403 if reply_msg['content']['status'] == u'error':
404 self.abort_queues()
403 # if reply_msg['content']['status'] == u'error':
404 # self.abort_queues()
405 405
406 406 def dispatch_queue(self, stream, msg):
407 407 self.control_stream.flush()
408 408 idents,msg = self.session.feed_identities(msg, copy=False)
409 409 msg = self.session.unpack_message(msg, content=True, copy=False)
410 410
411 411 header = msg['header']
412 412 msg_id = header['msg_id']
413 dependencies = header.get('dependencies', [])
414 413 if self.check_aborted(msg_id):
415 414 self.aborted.remove(msg_id)
416 415 # is it safe to assume a msg_id will not be resubmitted?
417 416 reply_type = msg['msg_type'].split('_')[0] + '_reply'
418 417 reply_msg = self.session.send(stream, reply_type,
419 418 content={'status' : 'aborted'}, parent=msg, ident=idents)
420 419 return
421 if not self.check_dependencies(dependencies):
422 return self.unmet_dependencies(stream, idents, msg)
423 420 handler = self.queue_handlers.get(msg['msg_type'], None)
424 421 if handler is None:
425 422 print >> sys.__stderr__, "UNKNOWN MESSAGE TYPE:", msg
426 423 else:
427 424 handler(stream, idents, msg)
428 425
429 426 def start(self):
430 427 #### stream mode:
431 428 if self.control_stream:
432 429 self.control_stream.on_recv(self.dispatch_control, copy=False)
433 430 self.control_stream.on_err(printer)
434 431 if self.reply_stream:
435 432 self.reply_stream.on_recv(lambda msg:
436 433 self.dispatch_queue(self.reply_stream, msg), copy=False)
437 434 self.reply_stream.on_err(printer)
438 435 if self.task_stream:
439 436 self.task_stream.on_recv(lambda msg:
440 437 self.dispatch_queue(self.task_stream, msg), copy=False)
441 438 self.task_stream.on_err(printer)
442 439
443 440 #### while True mode:
444 441 # while True:
445 442 # idle = True
446 443 # try:
447 444 # msg = self.reply_stream.socket.recv_multipart(
448 445 # zmq.NOBLOCK, copy=False)
449 446 # except zmq.ZMQError, e:
450 447 # if e.errno != zmq.EAGAIN:
451 448 # raise e
452 449 # else:
453 450 # idle=False
454 451 # self.dispatch_queue(self.reply_stream, msg)
455 452 #
456 453 # if not self.task_stream.empty():
457 454 # idle=False
458 455 # msg = self.task_stream.recv_multipart()
459 456 # self.dispatch_queue(self.task_stream, msg)
460 457 # if idle:
461 458 # # don't busywait
462 459 # time.sleep(1e-3)
463 460
464 461
465 462 def main():
466 463 raise Exception("Don't run me anymore")
467 464 loop = ioloop.IOLoop.instance()
468 465 c = zmq.Context()
469 466
470 467 ip = '127.0.0.1'
471 468 port_base = 5575
472 469 connection = ('tcp://%s' % ip) + ':%i'
473 470 rep_conn = connection % port_base
474 471 pub_conn = connection % (port_base+1)
475 472
476 473 print >>sys.__stdout__, "Starting the kernel..."
477 474 # print >>sys.__stdout__, "XREQ Channel:", rep_conn
478 475 # print >>sys.__stdout__, "PUB Channel:", pub_conn
479 476
480 477 session = StreamSession(username=u'kernel')
481 478
482 479 reply_socket = c.socket(zmq.XREQ)
483 480 reply_socket.connect(rep_conn)
484 481
485 482 pub_socket = c.socket(zmq.PUB)
486 483 pub_socket.connect(pub_conn)
487 484
488 485 stdout = OutStream(session, pub_socket, u'stdout')
489 486 stderr = OutStream(session, pub_socket, u'stderr')
490 487 sys.stdout = stdout
491 488 sys.stderr = stderr
492 489
493 490 display_hook = DisplayHook(session, pub_socket)
494 491 sys.displayhook = display_hook
495 492 reply_stream = zmqstream.ZMQStream(reply_socket,loop)
496 493 pub_stream = zmqstream.ZMQStream(pub_socket,loop)
497 494 kernel = Kernel(session, reply_stream, pub_stream)
498 495
499 496 # For debugging convenience, put sleep and a string in the namespace, so we
500 497 # have them every time we start.
501 498 kernel.user_ns['sleep'] = time.sleep
502 499 kernel.user_ns['s'] = 'Test string'
503 500
504 501 print >>sys.__stdout__, "Use Ctrl-\\ (NOT Ctrl-C!) to terminate."
505 502 kernel.start()
506 503 loop.start()
507 504
508 505
509 506 if __name__ == '__main__':
510 507 main()
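A standalone sketch (not part of the diff) of the name mangling apply_request performs when bound=True: the function, args, and kwargs are injected into the user namespace under names salted with the msg_id to avoid collisions, then executed. The msg_id and f here are made up.

    msg_id = "0123-abcd"
    suffix = str(msg_id).replace("-", "")   # bound=True: salt names with the msg_id
    prefix = "_"

    def f(a, b=1):
        return a + b

    fname = prefix + f.func_name.strip('<>') + suffix
    argname = prefix + "args" + suffix
    kwargname = prefix + "kwargs" + suffix
    resultname = prefix + "result" + suffix

    working = {fname: f, argname: (2,), kwargname: {'b': 3}}
    code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
    exec code in working, working
    print working.get(resultname)    # 5

In the unbound case the handler instead uses a throwaway dict and suffix = prefix = "_", so nothing leaks into user_ns.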
@@ -1,447 +1,447 @@
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11
12 12 import zmq
13 13 from zmq.utils import jsonapi
14 14 from zmq.eventloop.zmqstream import ZMQStream
15 15
16 16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
17 17 from IPython.zmq.newserialized import serialize, unserialize
18 18
19 19 try:
20 20 import cPickle
21 21 pickle = cPickle
22 22 except:
23 23 cPickle = None
24 24 import pickle
25 25
26 26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
27 27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
28 28 if json_name in ('jsonlib', 'jsonlib2'):
29 29 use_json = True
30 30 elif json_name:
31 31 if cPickle is None:
32 32 use_json = True
33 33 else:
34 34 use_json = False
35 35 else:
36 36 use_json = False
37 37
38 38 if use_json:
39 39 default_packer = jsonapi.dumps
40 40 default_unpacker = jsonapi.loads
41 41 else:
42 42 default_packer = lambda o: pickle.dumps(o,-1)
43 43 default_unpacker = pickle.loads
44 44
45 45
46 46 DELIM="<IDS|MSG>"
47 47
48 48 def wrap_exception():
49 49 etype, evalue, tb = sys.exc_info()
50 50 tb = traceback.format_exception(etype, evalue, tb)
51 51 exc_content = {
52 52 u'status' : u'error',
53 53 u'traceback' : tb,
54 54 u'etype' : unicode(etype),
55 55 u'evalue' : unicode(evalue)
56 56 }
57 57 return exc_content
58 58
59 59 class KernelError(Exception):
60 60 pass
61 61
62 62 def unwrap_exception(content):
63 63 err = KernelError(content['etype'], content['evalue'])
64 64 err.evalue = content['evalue']
65 65 err.etype = content['etype']
66 66 err.traceback = ''.join(content['traceback'])
67 67 return err
68 68
69 69
70 70 class Message(object):
71 71 """A simple message object that maps dict keys to attributes.
72 72
73 73 A Message can be created from a dict and a dict from a Message instance
74 74 simply by calling dict(msg_obj)."""
75 75
76 76 def __init__(self, msg_dict):
77 77 dct = self.__dict__
78 78 for k, v in dict(msg_dict).iteritems():
79 79 if isinstance(v, dict):
80 80 v = Message(v)
81 81 dct[k] = v
82 82
83 83 # Having this iterator lets dict(msg_obj) work out of the box.
84 84 def __iter__(self):
85 85 return iter(self.__dict__.iteritems())
86 86
87 87 def __repr__(self):
88 88 return repr(self.__dict__)
89 89
90 90 def __str__(self):
91 91 return pprint.pformat(self.__dict__)
92 92
93 93 def __contains__(self, k):
94 94 return k in self.__dict__
95 95
96 96 def __getitem__(self, k):
97 97 return self.__dict__[k]
98 98
99 99
100 100 def msg_header(msg_id, msg_type, username, session):
101 101 return locals()
102 102 # return {
103 103 # 'msg_id' : msg_id,
104 104 # 'msg_type': msg_type,
105 105 # 'username' : username,
106 106 # 'session' : session
107 107 # }
108 108
109 109
110 110 def extract_header(msg_or_header):
111 111 """Given a message or header, return the header."""
112 112 if not msg_or_header:
113 113 return {}
114 114 try:
115 115 # See if msg_or_header is the entire message.
116 116 h = msg_or_header['header']
117 117 except KeyError:
118 118 try:
119 119 # See if msg_or_header is just the header
120 120 h = msg_or_header['msg_id']
121 121 except KeyError:
122 122 raise
123 123 else:
124 124 h = msg_or_header
125 125 if not isinstance(h, dict):
126 126 h = dict(h)
127 127 return h
128 128
129 129 def rekey(dikt):
130 130 """rekey a dict that has been forced to use str keys where there should be
131 131 ints by json. This belongs in the jsonutil added by fperez."""
132 132 for k in dikt.iterkeys():
133 133 if isinstance(k, str):
134 134 ik=fk=None
135 135 try:
136 136 ik = int(k)
137 137 except ValueError:
138 138 try:
139 139 fk = float(k)
140 140 except ValueError:
141 141 continue
142 142 if ik is not None:
143 143 nk = ik
144 144 else:
145 145 nk = fk
146 146 if nk in dikt:
147 147 raise KeyError("already have key %r"%nk)
148 148 dikt[nk] = dikt.pop(k)
149 149 return dikt
150 150
151 151 def serialize_object(obj, threshold=64e-6):
152 152 """serialize an object into a list of sendable buffers.
153 153
154 154 Returns: (pmd, bufs)
155 155 where pmd is the pickled metadata wrapper, and bufs
156 156 is a list of data buffers"""
157 157 # threshold is 100 B
158 158 databuffers = []
159 159 if isinstance(obj, (list, tuple)):
160 160 clist = canSequence(obj)
161 161 slist = map(serialize, clist)
162 162 for s in slist:
163 163 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
164 164 databuffers.append(s.getData())
165 165 s.data = None
166 166 return pickle.dumps(slist,-1), databuffers
167 167 elif isinstance(obj, dict):
168 168 sobj = {}
169 169 for k in sorted(obj.iterkeys()):
170 170 s = serialize(can(obj[k]))
171 171 if s.getDataSize() > threshold:
172 172 databuffers.append(s.getData())
173 173 s.data = None
174 174 sobj[k] = s
175 175 return pickle.dumps(sobj,-1),databuffers
176 176 else:
177 177 s = serialize(can(obj))
178 178 if s.getDataSize() > threshold:
179 179 databuffers.append(s.getData())
180 180 s.data = None
181 181 return pickle.dumps(s,-1),databuffers
182 182
183 183
184 184 def unserialize_object(bufs):
185 185 """reconstruct an object serialized by serialize_object from data buffers"""
186 186 bufs = list(bufs)
187 187 sobj = pickle.loads(bufs.pop(0))
188 188 if isinstance(sobj, (list, tuple)):
189 189 for s in sobj:
190 190 if s.data is None:
191 191 s.data = bufs.pop(0)
192 192 return uncanSequence(map(unserialize, sobj))
193 193 elif isinstance(sobj, dict):
194 194 newobj = {}
195 195 for k in sorted(sobj.iterkeys()):
196 196 s = sobj[k]
197 197 if s.data is None:
198 198 s.data = bufs.pop(0)
199 199 newobj[k] = uncan(unserialize(s))
200 200 return newobj
201 201 else:
202 202 if sobj.data is None:
203 203 sobj.data = bufs.pop(0)
204 204 return uncan(unserialize(sobj))
205 205
206 206 def pack_apply_message(f, args, kwargs, threshold=64e-6):
207 207 """pack up a function, args, and kwargs to be sent over the wire
208 208 as a series of buffers. Any object whose data is larger than `threshold`
209 209 will not have their data copied (currently only numpy arrays support zero-copy)"""
210 210 msg = [pickle.dumps(can(f),-1)]
211 211 databuffers = [] # for large objects
212 212 sargs, bufs = serialize_object(args,threshold)
213 213 msg.append(sargs)
214 214 databuffers.extend(bufs)
215 215 skwargs, bufs = serialize_object(kwargs,threshold)
216 216 msg.append(skwargs)
217 217 databuffers.extend(bufs)
218 218 msg.extend(databuffers)
219 219 return msg
220 220
221 221 def unpack_apply_message(bufs, g=None, copy=True):
222 222 """unpack f,args,kwargs from buffers packed by pack_apply_message()
223 223 Returns: original f,args,kwargs"""
224 224 bufs = list(bufs) # allow us to pop
225 225 assert len(bufs) >= 3, "not enough buffers!"
226 226 if not copy:
227 227 for i in range(3):
228 228 bufs[i] = bufs[i].bytes
229 229 cf = pickle.loads(bufs.pop(0))
230 230 sargs = list(pickle.loads(bufs.pop(0)))
231 231 skwargs = dict(pickle.loads(bufs.pop(0)))
232 232 # print sargs, skwargs
233 f = cf.getFunction(g)
233 f = uncan(cf, g)
234 234 for sa in sargs:
235 235 if sa.data is None:
236 236 m = bufs.pop(0)
237 237 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
238 238 if copy:
239 239 sa.data = buffer(m)
240 240 else:
241 241 sa.data = m.buffer
242 242 else:
243 243 if copy:
244 244 sa.data = m
245 245 else:
246 246 sa.data = m.bytes
247 247
248 248 args = uncanSequence(map(unserialize, sargs), g)
249 249 kwargs = {}
250 250 for k in sorted(skwargs.iterkeys()):
251 251 sa = skwargs[k]
252 252 if sa.data is None:
253 253 sa.data = bufs.pop(0)
254 254 kwargs[k] = uncan(unserialize(sa), g)
255 255
256 256 return f,args,kwargs
257 257
258 258 class StreamSession(object):
259 259 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
260 260 debug=False
261 261 def __init__(self, username=None, session=None, packer=None, unpacker=None):
262 262 if username is None:
263 263 username = os.environ.get('USER','username')
264 264 self.username = username
265 265 if session is None:
266 266 self.session = str(uuid.uuid4())
267 267 else:
268 268 self.session = session
269 269 self.msg_id = str(uuid.uuid4())
270 270 if packer is None:
271 271 self.pack = default_packer
272 272 else:
273 273 if not callable(packer):
274 274 raise TypeError("packer must be callable, not %s"%type(packer))
275 275 self.pack = packer
276 276
277 277 if unpacker is None:
278 278 self.unpack = default_unpacker
279 279 else:
280 280 if not callable(unpacker):
281 281 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
282 282 self.unpack = unpacker
283 283
284 284 self.none = self.pack({})
285 285
286 286 def msg_header(self, msg_type):
287 287 h = msg_header(self.msg_id, msg_type, self.username, self.session)
288 288 self.msg_id = str(uuid.uuid4())
289 289 return h
290 290
291 291 def msg(self, msg_type, content=None, parent=None, subheader=None):
292 292 msg = {}
293 293 msg['header'] = self.msg_header(msg_type)
294 294 msg['msg_id'] = msg['header']['msg_id']
295 295 msg['parent_header'] = {} if parent is None else extract_header(parent)
296 296 msg['msg_type'] = msg_type
297 297 msg['content'] = {} if content is None else content
298 298 sub = {} if subheader is None else subheader
299 299 msg['header'].update(sub)
300 300 return msg
301 301
302 302 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
303 303 """send a message via stream"""
304 304 msg = self.msg(msg_type, content, parent, subheader)
305 305 buffers = [] if buffers is None else buffers
306 306 to_send = []
307 307 if isinstance(ident, list):
308 308 # accept list of idents
309 309 to_send.extend(ident)
310 310 elif ident is not None:
311 311 to_send.append(ident)
312 312 to_send.append(DELIM)
313 313 to_send.append(self.pack(msg['header']))
314 314 to_send.append(self.pack(msg['parent_header']))
315 315 # if parent is None:
316 316 # to_send.append(self.none)
317 317 # else:
318 318 # to_send.append(self.pack(dict(parent)))
319 319 if content is None:
320 320 content = self.none
321 321 elif isinstance(content, dict):
322 322 content = self.pack(content)
323 323 elif isinstance(content, str):
324 324 # content is already packed, as in a relayed message
325 325 pass
326 326 else:
327 327 raise TypeError("Content incorrect type: %s"%type(content))
328 328 to_send.append(content)
329 329 flag = 0
330 330 if buffers:
331 331 flag = zmq.SNDMORE
332 332 stream.send_multipart(to_send, flag, copy=False)
333 333 for b in buffers[:-1]:
334 334 stream.send(b, flag, copy=False)
335 335 if buffers:
336 336 stream.send(buffers[-1], copy=False)
337 337 omsg = Message(msg)
338 338 if self.debug:
339 339 pprint.pprint(omsg)
340 340 pprint.pprint(to_send)
341 341 pprint.pprint(buffers)
342 342 return omsg
343 343
344 344 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
345 345 """receives and unpacks a message
346 346 returns [idents], msg"""
347 347 if isinstance(socket, ZMQStream):
348 348 socket = socket.socket
349 349 try:
350 350 msg = socket.recv_multipart(mode)
351 351 except zmq.ZMQError, e:
352 352 if e.errno == zmq.EAGAIN:
353 353 # We can convert EAGAIN to None as we know in this case
354 354 # recv_json won't return None.
355 355 return None
356 356 else:
357 357 raise
358 358 # return an actual Message object
359 359 # determine the number of idents by trying to unpack them.
360 360 # this is terrible:
361 361 idents, msg = self.feed_identities(msg, copy)
362 362 try:
363 363 return idents, self.unpack_message(msg, content=content, copy=copy)
364 364 except Exception, e:
365 365 print idents, msg
366 366 # TODO: handle it
367 367 raise e
368 368
369 369 def feed_identities(self, msg, copy=True):
370 370 """This is a completely horrible thing, but it strips the zmq
371 371 ident prefixes off of a message. It will break if any identities
372 372 are unpackable by self.unpack."""
373 373 msg = list(msg)
374 374 idents = []
375 375 while len(msg) > 3:
376 376 if copy:
377 377 s = msg[0]
378 378 else:
379 379 s = msg[0].bytes
380 380 if s == DELIM:
381 381 msg.pop(0)
382 382 break
383 383 else:
384 384 idents.append(s)
385 385 msg.pop(0)
386 386
387 387 return idents, msg
388 388
389 389 def unpack_message(self, msg, content=True, copy=True):
390 390 """return a message object from the format
391 391 sent by self.send.
392 392
393 393 parameters:
394 394
395 395 content : bool (True)
396 396 whether to unpack the content dict (True),
397 397 or leave it serialized (False)
398 398
399 399 copy : bool (True)
400 400 whether to return the bytes (True),
401 401 or the non-copying Message object in each place (False)
402 402
403 403 """
404 404 if not len(msg) >= 3:
405 405 raise TypeError("malformed message, must have at least 3 elements")
406 406 message = {}
407 407 if not copy:
408 408 for i in range(3):
409 409 msg[i] = msg[i].bytes
410 410 message['header'] = self.unpack(msg[0])
411 411 message['msg_type'] = message['header']['msg_type']
412 412 message['parent_header'] = self.unpack(msg[1])
413 413 if content:
414 414 message['content'] = self.unpack(msg[2])
415 415 else:
416 416 message['content'] = msg[2]
417 417
418 418 # message['buffers'] = msg[3:]
419 419 # else:
420 420 # message['header'] = self.unpack(msg[0].bytes)
421 421 # message['msg_type'] = message['header']['msg_type']
422 422 # message['parent_header'] = self.unpack(msg[1].bytes)
423 423 # if content:
424 424 # message['content'] = self.unpack(msg[2].bytes)
425 425 # else:
426 426 # message['content'] = msg[2].bytes
427 427
428 428 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
429 429 return message
430 430
431 431
432 432
433 433 def test_msg2obj():
434 434 am = dict(x=1)
435 435 ao = Message(am)
436 436 assert ao.x == am['x']
437 437
438 438 am['y'] = dict(z=1)
439 439 ao = Message(am)
440 440 assert ao.y.z == am['y']['z']
441 441
442 442 k1, k2 = 'y', 'z'
443 443 assert ao[k1][k2] == am[k1][k2]
444 444
445 445 am2 = dict(ao)
446 446 assert am['x'] == am2['x']
447 447 assert am['y']['z'] == am2['y']['z']
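A rough sketch (not part of the change) of the wire format send() now produces, with msg_type carried inside the header. It assumes this file is importable as streamsession and that pyzmq is available; a plain PAIR socket pair stands in for the usual streams, since send() only needs send_multipart().

    import zmq
    from streamsession import StreamSession

    ctx = zmq.Context()
    a = ctx.socket(zmq.PAIR)
    b = ctx.socket(zmq.PAIR)
    a.bind('inproc://demo')
    b.connect('inproc://demo')

    session = StreamSession(username=u'demo')
    # frames on the wire: [ident, DELIM, header, parent_header, content]
    session.send(a, 'apply_request', content={'bound': False}, ident='engine-0')

    idents, msg = session.recv(b, mode=0)     # mode=0 blocks instead of NOBLOCK
    print idents                              # ['engine-0']
    print msg['msg_type']                     # 'apply_request', read from the header
    print msg['content']                      # {'bound': False}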
@@ -1,95 +1,115 @@
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 from types import FunctionType
19 19
20 20 # contents of codeutil should either be in here, or codeutil belongs in IPython/util
21 21 from IPython.kernel import codeutil
22 from IPython.zmq.parallel.dependency import dependent
22 23
23 24 class CannedObject(object):
24 pass
25 def __init__(self, obj, keys=[]):
26 self.keys = keys
27 self.obj = obj
28 for key in keys:
29 setattr(obj, key, can(getattr(obj, key)))
30
25 31
32 def getObject(self, g=None):
33 if g is None:
34 g = globals()
35 for key in self.keys:
36 setattr(self.obj, key, uncan(getattr(self.obj, key), g))
37 return self.obj
38
39
40
26 41 class CannedFunction(CannedObject):
27 42
28 43 def __init__(self, f):
29 44 self._checkType(f)
30 45 self.code = f.func_code
31 46
32 47 def _checkType(self, obj):
33 48 assert isinstance(obj, FunctionType), "Not a function type"
34 49
35 50 def getFunction(self, g=None):
36 51 if g is None:
37 52 g = globals()
38 53 newFunc = FunctionType(self.code, g)
39 54 return newFunc
40 55
41 56 def can(obj):
42 57 if isinstance(obj, FunctionType):
43 58 return CannedFunction(obj)
59 elif isinstance(obj, dependent):
60 keys = ('f','df')
61 return CannedObject(obj, keys=keys)
44 62 elif isinstance(obj,dict):
45 63 return canDict(obj)
46 64 elif isinstance(obj, (list,tuple)):
47 65 return canSequence(obj)
48 66 else:
49 67 return obj
50 68
51 69 def canDict(obj):
52 70 if isinstance(obj, dict):
53 71 newobj = {}
54 72 for k, v in obj.iteritems():
55 73 newobj[k] = can(v)
56 74 return newobj
57 75 else:
58 76 return obj
59 77
60 78 def canSequence(obj):
61 79 if isinstance(obj, (list, tuple)):
62 80 t = type(obj)
63 81 return t([can(i) for i in obj])
64 82 else:
65 83 return obj
66 84
67 85 def uncan(obj, g=None):
68 86 if isinstance(obj, CannedFunction):
69 87 return obj.getFunction(g)
88 elif isinstance(obj, CannedObject):
89 return obj.getObject(g)
70 90 elif isinstance(obj,dict):
71 91 return uncanDict(obj)
72 92 elif isinstance(obj, (list,tuple)):
73 93 return uncanSequence(obj)
74 94 else:
75 95 return obj
76 96
77 97 def uncanDict(obj, g=None):
78 98 if isinstance(obj, dict):
79 99 newobj = {}
80 100 for k, v in obj.iteritems():
81 101 newobj[k] = uncan(v,g)
82 102 return newobj
83 103 else:
84 104 return obj
85 105
86 106 def uncanSequence(obj, g=None):
87 107 if isinstance(obj, (list, tuple)):
88 108 t = type(obj)
89 109 return t([uncan(i,g) for i in obj])
90 110 else:
91 111 return obj
92 112
93 113
94 114 def rebindFunctionGlobals(f, glbls):
95 115 return FunctionType(f.func_code, glbls)
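A sketch (not part of the change) of why dependent gets the CannedObject treatment: its f and df attributes are canned as functions, so the whole wrapper survives pickling. The module paths are the ones used in this diff; work is a made-up example.

    import pickle
    # importing pickleutil also pulls in IPython.kernel.codeutil, which registers
    # a pickler for code objects, so CannedFunction instances can be pickled
    from IPython.zmq.pickleutil import can, uncan
    from IPython.zmq.parallel.dependency import depend

    @depend(lambda: True)
    def work(x):
        return x * 2

    canned = can(work)                  # CannedObject with 'f' and 'df' canned
    wire = pickle.dumps(canned, -1)     # roughly what pack_apply_message ships
    restored = uncan(pickle.loads(wire))
    print restored(21)                  # 42, after the (trivially met) dependency check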