scheduler progress
MinRK
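The centerpiece of this changeset is that `_apply_balanced` now forwards `after`/`follow` dependencies to the task scheduler through the message subheader, matching what `_apply_direct` already did. A minimal usage sketch under stated assumptions (the module name `client`, the connection address, and passing plain lists of msg_ids for `after` are placeholders, not shown in this diff):

from client import Client

def step_one():
    return 1

def step_two():
    return 2

c = Client('tcp://127.0.0.1:10101')            # hypothetical registration address
# load-balanced, non-blocking submissions (targets=None routes via the task socket)
first = c.apply(step_one, targets=None, block=False)
# ask the scheduler to run step_two only after `first` has completed
second = c.apply(step_two, targets=None, block=False, after=[first])
c.barrier([first, second])                      # wait for both results
print c.results[second]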
@@ -1,640 +1,654 @@
1 #!/usr/bin/env python
2 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
3
8
4 import time
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
5
12
13 import time
6 from pprint import pprint
14 from pprint import pprint
7
15
16 import zmq
17 from zmq.eventloop import ioloop, zmqstream
18
8 from IPython.external.decorator import decorator
19 from IPython.external.decorator import decorator
9
20
10 import streamsession as ss
21 import streamsession as ss
11 import zmq
12 from zmq.eventloop import ioloop, zmqstream
13 from remotenamespace import RemoteNamespace
22 from remotenamespace import RemoteNamespace
14 from view import DirectView
23 from view import DirectView
15 from dependency import Dependency, depend, require
24 from dependency import Dependency, depend, require
16
25
17 def _push(ns):
26 def _push(ns):
18 globals().update(ns)
27 globals().update(ns)
19
28
20 def _pull(keys):
29 def _pull(keys):
21 g = globals()
30 g = globals()
22 if isinstance(keys, (list,tuple)):
31 if isinstance(keys, (list,tuple)):
23 return map(g.get, keys)
32 return map(g.get, keys)
24 else:
33 else:
25 return g.get(keys)
34 return g.get(keys)
26
35
27 def _clear():
36 def _clear():
28 globals().clear()
37 globals().clear()
29
38
30 def execute(code):
39 def execute(code):
31 exec code in globals()
40 exec code in globals()
32
41
33 # decorators for methods:
42 # decorators for methods:
34 @decorator
43 @decorator
35 def spinfirst(f,self,*args,**kwargs):
44 def spinfirst(f,self,*args,**kwargs):
36 self.spin()
45 self.spin()
37 return f(self, *args, **kwargs)
46 return f(self, *args, **kwargs)
38
47
39 @decorator
48 @decorator
40 def defaultblock(f, self, *args, **kwargs):
49 def defaultblock(f, self, *args, **kwargs):
41 block = kwargs.get('block',None)
50 block = kwargs.get('block',None)
42 block = self.block if block is None else block
51 block = self.block if block is None else block
43 saveblock = self.block
52 saveblock = self.block
44 self.block = block
53 self.block = block
45 ret = f(self, *args, **kwargs)
54 ret = f(self, *args, **kwargs)
46 self.block = saveblock
55 self.block = saveblock
47 return ret
56 return ret
48
57
49 class AbortedTask(object):
58 class AbortedTask(object):
50 def __init__(self, msg_id):
59 def __init__(self, msg_id):
51 self.msg_id = msg_id
60 self.msg_id = msg_id
52 # @decorator
61 # @decorator
53 # def checktargets(f):
62 # def checktargets(f):
54 # @wraps(f)
63 # @wraps(f)
55 # def checked_method(self, *args, **kwargs):
64 # def checked_method(self, *args, **kwargs):
56 # self._build_targets(kwargs['targets'])
65 # self._build_targets(kwargs['targets'])
57 # return f(self, *args, **kwargs)
66 # return f(self, *args, **kwargs)
58 # return checked_method
67 # return checked_method
59
68
60
69
61 # class _ZMQEventLoopThread(threading.Thread):
70 # class _ZMQEventLoopThread(threading.Thread):
62 #
71 #
63 # def __init__(self, loop):
72 # def __init__(self, loop):
64 # self.loop = loop
73 # self.loop = loop
65 # threading.Thread.__init__(self)
74 # threading.Thread.__init__(self)
66 #
75 #
67 # def run(self):
76 # def run(self):
68 # self.loop.start()
77 # self.loop.start()
69 #
78 #
70 class Client(object):
79 class Client(object):
71 """A semi-synchronous client to the IPython ZMQ controller
80 """A semi-synchronous client to the IPython ZMQ controller
72
81
73 Attributes
82 Attributes
74 ----------
83 ----------
75 ids : set
84 ids : set
76 a set of engine IDs
85 a set of engine IDs
77 requesting the ids attribute always synchronizes
86 requesting the ids attribute always synchronizes
78 the registration state. To request ids without synchronization,
87 the registration state. To request ids without synchronization,
79 use _ids
88 use _ids
80
89
81 history : list of msg_ids
90 history : list of msg_ids
82 a list of msg_ids, keeping track of all the execution
91 a list of msg_ids, keeping track of all the execution
83 messages you have submitted
92 messages you have submitted
84
93
85 outstanding : set of msg_ids
94 outstanding : set of msg_ids
86 a set of msg_ids that have been submitted, but whose
95 a set of msg_ids that have been submitted, but whose
87 results have not been received
96 results have not been received
88
97
89 results : dict
98 results : dict
90 a dict of all our results, keyed by msg_id
99 a dict of all our results, keyed by msg_id
91
100
92 block : bool
101 block : bool
93 determines default behavior when block not specified
102 determines default behavior when block not specified
94 in execution methods
103 in execution methods
95
104
96 Methods
105 Methods
97 -------
106 -------
98 spin : flushes incoming results and registration state changes
107 spin : flushes incoming results and registration state changes
99 control methods spin, and requesting `ids` also ensures up to date
108 control methods spin, and requesting `ids` also ensures up to date
100
109
101 barrier : wait on one or more msg_ids
110 barrier : wait on one or more msg_ids
102
111
103 execution methods: apply/apply_bound/apply_to
112 execution methods: apply/apply_bound/apply_to
104 legacy: execute, run
113 legacy: execute, run
105
114
106 query methods: queue_status, get_result
115 query methods: queue_status, get_result
107
116
108 control methods: abort, kill
117 control methods: abort, kill
109
118
110
119
111
120
112 """
121 """
113
122
114
123
115 _connected=False
124 _connected=False
116 _engines=None
125 _engines=None
117 registration_socket=None
126 registration_socket=None
118 query_socket=None
127 query_socket=None
119 control_socket=None
128 control_socket=None
120 notification_socket=None
129 notification_socket=None
121 queue_socket=None
130 queue_socket=None
122 task_socket=None
131 task_socket=None
123 block = False
132 block = False
124 outstanding=None
133 outstanding=None
125 results = None
134 results = None
126 history = None
135 history = None
127 debug = False
136 debug = False
128
137
129 def __init__(self, addr, context=None, username=None, debug=False):
138 def __init__(self, addr, context=None, username=None, debug=False):
130 if context is None:
139 if context is None:
131 context = zmq.Context()
140 context = zmq.Context()
132 self.context = context
141 self.context = context
133 self.addr = addr
142 self.addr = addr
134 if username is None:
143 if username is None:
135 self.session = ss.StreamSession()
144 self.session = ss.StreamSession()
136 else:
145 else:
137 self.session = ss.StreamSession(username)
146 self.session = ss.StreamSession(username)
138 self.registration_socket = self.context.socket(zmq.PAIR)
147 self.registration_socket = self.context.socket(zmq.PAIR)
139 self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
148 self.registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
140 self.registration_socket.connect(addr)
149 self.registration_socket.connect(addr)
141 self._engines = {}
150 self._engines = {}
142 self._ids = set()
151 self._ids = set()
143 self.outstanding=set()
152 self.outstanding=set()
144 self.results = {}
153 self.results = {}
145 self.history = []
154 self.history = []
146 self.debug = debug
155 self.debug = debug
147 self.session.debug = debug
156 self.session.debug = debug
148
157
149 self._notification_handlers = {'registration_notification' : self._register_engine,
158 self._notification_handlers = {'registration_notification' : self._register_engine,
150 'unregistration_notification' : self._unregister_engine,
159 'unregistration_notification' : self._unregister_engine,
151 }
160 }
152 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
161 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
153 'apply_reply' : self._handle_apply_reply}
162 'apply_reply' : self._handle_apply_reply}
154 self._connect()
163 self._connect()
155
164
156
165
157 @property
166 @property
158 def ids(self):
167 def ids(self):
159 self._flush_notifications()
168 self._flush_notifications()
160 return self._ids
169 return self._ids
161
170
162 def _update_engines(self, engines):
171 def _update_engines(self, engines):
163 for k,v in engines.iteritems():
172 for k,v in engines.iteritems():
164 eid = int(k)
173 eid = int(k)
165 self._engines[eid] = bytes(v) # force not unicode
174 self._engines[eid] = bytes(v) # force not unicode
166 self._ids.add(eid)
175 self._ids.add(eid)
167
176
168 def _build_targets(self, targets):
177 def _build_targets(self, targets):
169 if targets is None:
178 if targets is None:
170 targets = self._ids
179 targets = self._ids
171 elif isinstance(targets, str):
180 elif isinstance(targets, str):
172 if targets.lower() == 'all':
181 if targets.lower() == 'all':
173 targets = self._ids
182 targets = self._ids
174 else:
183 else:
175 raise TypeError("%r not valid str target, must be 'all'"%(targets))
184 raise TypeError("%r not valid str target, must be 'all'"%(targets))
176 elif isinstance(targets, int):
185 elif isinstance(targets, int):
177 targets = [targets]
186 targets = [targets]
178 return [self._engines[t] for t in targets], list(targets)
187 return [self._engines[t] for t in targets], list(targets)
179
188
180 def _connect(self):
189 def _connect(self):
181 """setup all our socket connections to the controller"""
190 """setup all our socket connections to the controller"""
182 if self._connected:
191 if self._connected:
183 return
192 return
184 self._connected=True
193 self._connected=True
185 self.session.send(self.registration_socket, 'connection_request')
194 self.session.send(self.registration_socket, 'connection_request')
186 idents,msg = self.session.recv(self.registration_socket,mode=0)
195 idents,msg = self.session.recv(self.registration_socket,mode=0)
187 if self.debug:
196 if self.debug:
188 pprint(msg)
197 pprint(msg)
189 msg = ss.Message(msg)
198 msg = ss.Message(msg)
190 content = msg.content
199 content = msg.content
191 if content.status == 'ok':
200 if content.status == 'ok':
192 if content.queue:
201 if content.queue:
193 self.queue_socket = self.context.socket(zmq.PAIR)
202 self.queue_socket = self.context.socket(zmq.PAIR)
194 self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session)
203 self.queue_socket.setsockopt(zmq.IDENTITY, self.session.session)
195 self.queue_socket.connect(content.queue)
204 self.queue_socket.connect(content.queue)
196 if content.task:
205 if content.task:
197 self.task_socket = self.context.socket(zmq.PAIR)
206 self.task_socket = self.context.socket(zmq.PAIR)
198 self.task_socket.setsockopt(zmq.IDENTITY, self.session.session)
207 self.task_socket.setsockopt(zmq.IDENTITY, self.session.session)
199 self.task_socket.connect(content.task)
208 self.task_socket.connect(content.task)
200 if content.notification:
209 if content.notification:
201 self.notification_socket = self.context.socket(zmq.SUB)
210 self.notification_socket = self.context.socket(zmq.SUB)
202 self.notification_socket.connect(content.notification)
211 self.notification_socket.connect(content.notification)
203 self.notification_socket.setsockopt(zmq.SUBSCRIBE, "")
212 self.notification_socket.setsockopt(zmq.SUBSCRIBE, "")
204 if content.query:
213 if content.query:
205 self.query_socket = self.context.socket(zmq.PAIR)
214 self.query_socket = self.context.socket(zmq.PAIR)
206 self.query_socket.setsockopt(zmq.IDENTITY, self.session.session)
215 self.query_socket.setsockopt(zmq.IDENTITY, self.session.session)
207 self.query_socket.connect(content.query)
216 self.query_socket.connect(content.query)
208 if content.control:
217 if content.control:
209 self.control_socket = self.context.socket(zmq.PAIR)
218 self.control_socket = self.context.socket(zmq.PAIR)
210 self.control_socket.setsockopt(zmq.IDENTITY, self.session.session)
219 self.control_socket.setsockopt(zmq.IDENTITY, self.session.session)
211 self.control_socket.connect(content.control)
220 self.control_socket.connect(content.control)
212 self._update_engines(dict(content.engines))
221 self._update_engines(dict(content.engines))
213
222
214 else:
223 else:
215 self._connected = False
224 self._connected = False
216 raise Exception("Failed to connect!")
225 raise Exception("Failed to connect!")
217
226
218 #### handlers and callbacks for incoming messages #######
227 #### handlers and callbacks for incoming messages #######
219 def _register_engine(self, msg):
228 def _register_engine(self, msg):
220 content = msg['content']
229 content = msg['content']
221 eid = content['id']
230 eid = content['id']
222 d = {eid : content['queue']}
231 d = {eid : content['queue']}
223 self._update_engines(d)
232 self._update_engines(d)
224 self._ids.add(int(eid))
233 self._ids.add(int(eid))
225
234
226 def _unregister_engine(self, msg):
235 def _unregister_engine(self, msg):
227 # print 'unregister',msg
236 # print 'unregister',msg
228 content = msg['content']
237 content = msg['content']
229 eid = int(content['id'])
238 eid = int(content['id'])
230 if eid in self._ids:
239 if eid in self._ids:
231 self._ids.remove(eid)
240 self._ids.remove(eid)
232 self._engines.pop(eid)
241 self._engines.pop(eid)
233
242
234 def _handle_execute_reply(self, msg):
243 def _handle_execute_reply(self, msg):
235 # msg_id = msg['msg_id']
244 # msg_id = msg['msg_id']
236 parent = msg['parent_header']
245 parent = msg['parent_header']
237 msg_id = parent['msg_id']
246 msg_id = parent['msg_id']
238 if msg_id not in self.outstanding:
247 if msg_id not in self.outstanding:
239 print "got unknown result: %s"%msg_id
248 print "got unknown result: %s"%msg_id
240 else:
249 else:
241 self.outstanding.remove(msg_id)
250 self.outstanding.remove(msg_id)
242 self.results[msg_id] = ss.unwrap_exception(msg['content'])
251 self.results[msg_id] = ss.unwrap_exception(msg['content'])
243
252
244 def _handle_apply_reply(self, msg):
253 def _handle_apply_reply(self, msg):
245 # pprint(msg)
254 # pprint(msg)
246 # msg_id = msg['msg_id']
255 # msg_id = msg['msg_id']
247 parent = msg['parent_header']
256 parent = msg['parent_header']
248 msg_id = parent['msg_id']
257 msg_id = parent['msg_id']
249 if msg_id not in self.outstanding:
258 if msg_id not in self.outstanding:
250 print "got unknown result: %s"%msg_id
259 print "got unknown result: %s"%msg_id
251 else:
260 else:
252 self.outstanding.remove(msg_id)
261 self.outstanding.remove(msg_id)
253 content = msg['content']
262 content = msg['content']
254 if content['status'] == 'ok':
263 if content['status'] == 'ok':
255 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
264 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
256 elif content['status'] == 'aborted':
265 elif content['status'] == 'aborted':
257 self.results[msg_id] = AbortedTask(msg_id)
266 self.results[msg_id] = AbortedTask(msg_id)
258 elif content['status'] == 'resubmitted':
267 elif content['status'] == 'resubmitted':
259 pass # handle resubmission
268 pass # handle resubmission
260 else:
269 else:
261 self.results[msg_id] = ss.unwrap_exception(content)
270 self.results[msg_id] = ss.unwrap_exception(content)
262
271
263 def _flush_notifications(self):
272 def _flush_notifications(self):
264 "flush incoming notifications of engine registrations"
273 "flush incoming notifications of engine registrations"
265 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
274 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
266 while msg is not None:
275 while msg is not None:
267 if self.debug:
276 if self.debug:
268 pprint(msg)
277 pprint(msg)
269 msg = msg[-1]
278 msg = msg[-1]
270 msg_type = msg['msg_type']
279 msg_type = msg['msg_type']
271 handler = self._notification_handlers.get(msg_type, None)
280 handler = self._notification_handlers.get(msg_type, None)
272 if handler is None:
281 if handler is None:
273 raise Exception("Unhandled message type: %s"%msg_type)
282 raise Exception("Unhandled message type: %s"%msg_type)
274 else:
283 else:
275 handler(msg)
284 handler(msg)
276 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
285 msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK)
277
286
278 def _flush_results(self, sock):
287 def _flush_results(self, sock):
279 "flush incoming task or queue results"
288 "flush incoming task or queue results"
280 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
289 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
281 while msg is not None:
290 while msg is not None:
282 if self.debug:
291 if self.debug:
283 pprint(msg)
292 pprint(msg)
284 msg = msg[-1]
293 msg = msg[-1]
285 msg_type = msg['msg_type']
294 msg_type = msg['msg_type']
286 handler = self._queue_handlers.get(msg_type, None)
295 handler = self._queue_handlers.get(msg_type, None)
287 if handler is None:
296 if handler is None:
288 raise Exception("Unhandled message type: %s"%msg_type)
297 raise Exception("Unhandled message type: %s"%msg_type)
289 else:
298 else:
290 handler(msg)
299 handler(msg)
291 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
300 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
292
301
293 def _flush_control(self, sock):
302 def _flush_control(self, sock):
294 "flush incoming control replies"
303 "flush incoming control replies"
295 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
304 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
296 while msg is not None:
305 while msg is not None:
297 if self.debug:
306 if self.debug:
298 pprint(msg)
307 pprint(msg)
299 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
308 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
300
309
301 ###### get/setitem ########
310 ###### get/setitem ########
302
311
303 def __getitem__(self, key):
312 def __getitem__(self, key):
304 if isinstance(key, int):
313 if isinstance(key, int):
305 if key not in self.ids:
314 if key not in self.ids:
306 raise IndexError("No such engine: %i"%key)
315 raise IndexError("No such engine: %i"%key)
307 return DirectView(self, key)
316 return DirectView(self, key)
308
317
309 if isinstance(key, slice):
318 if isinstance(key, slice):
310 indices = range(len(self.ids))[key]
319 indices = range(len(self.ids))[key]
311 ids = sorted(self._ids)
320 ids = sorted(self._ids)
312 key = [ ids[i] for i in indices ]
321 key = [ ids[i] for i in indices ]
313 # newkeys = sorted(self._ids)[thekeys[k]]
322 # newkeys = sorted(self._ids)[thekeys[k]]
314
323
315 if isinstance(key, (tuple, list, xrange)):
324 if isinstance(key, (tuple, list, xrange)):
316 _,targets = self._build_targets(list(key))
325 _,targets = self._build_targets(list(key))
317 return DirectView(self, targets)
326 return DirectView(self, targets)
318 else:
327 else:
319 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
328 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
320
329
321 ############ begin real methods #############
330 ############ begin real methods #############
322
331
323 def spin(self):
332 def spin(self):
324 """flush incoming notifications and execution results."""
333 """flush incoming notifications and execution results."""
325 if self.notification_socket:
334 if self.notification_socket:
326 self._flush_notifications()
335 self._flush_notifications()
327 if self.queue_socket:
336 if self.queue_socket:
328 self._flush_results(self.queue_socket)
337 self._flush_results(self.queue_socket)
329 if self.task_socket:
338 if self.task_socket:
330 self._flush_results(self.task_socket)
339 self._flush_results(self.task_socket)
331 if self.control_socket:
340 if self.control_socket:
332 self._flush_control(self.control_socket)
341 self._flush_control(self.control_socket)
333
342
334 @spinfirst
343 @spinfirst
335 def queue_status(self, targets=None, verbose=False):
344 def queue_status(self, targets=None, verbose=False):
336 """fetch the status of engine queues
345 """fetch the status of engine queues
337
346
338 Parameters
347 Parameters
339 ----------
348 ----------
340 targets : int/str/list of ints/strs
349 targets : int/str/list of ints/strs
341 the engines on which to execute
350 the engines on which to execute
342 default : all
351 default : all
343 verbose : bool
352 verbose : bool
344 whether to return lengths only, or lists of ids for each element
353 whether to return lengths only, or lists of ids for each element
345
354
346 """
355 """
347 targets = self._build_targets(targets)[1]
356 targets = self._build_targets(targets)[1]
348 content = dict(targets=targets)
357 content = dict(targets=targets)
349 self.session.send(self.query_socket, "queue_request", content=content)
358 self.session.send(self.query_socket, "queue_request", content=content)
350 idents,msg = self.session.recv(self.query_socket, 0)
359 idents,msg = self.session.recv(self.query_socket, 0)
351 if self.debug:
360 if self.debug:
352 pprint(msg)
361 pprint(msg)
353 return msg['content']
362 return msg['content']
354
363
355 @spinfirst
364 @spinfirst
356 @defaultblock
365 @defaultblock
357 def clear(self, targets=None, block=None):
366 def clear(self, targets=None, block=None):
358 """clear the namespace in target(s)"""
367 """clear the namespace in target(s)"""
359 targets = self._build_targets(targets)[0]
368 targets = self._build_targets(targets)[0]
360 print targets
361 for t in targets:
369 for t in targets:
362 self.session.send(self.control_socket, 'clear_request', content={},ident=t)
370 self.session.send(self.control_socket, 'clear_request', content={},ident=t)
363 error = False
371 error = False
364 if self.block:
372 if self.block:
365 for i in range(len(targets)):
373 for i in range(len(targets)):
366 idents,msg = self.session.recv(self.control_socket,0)
374 idents,msg = self.session.recv(self.control_socket,0)
367 if self.debug:
375 if self.debug:
368 pprint(msg)
376 pprint(msg)
369 if msg['content']['status'] != 'ok':
377 if msg['content']['status'] != 'ok':
370 error = msg['content']
378 error = msg['content']
371 if error:
379 if error:
372 return error
380 return error
373
381
374
382
375 @spinfirst
383 @spinfirst
376 @defaultblock
384 @defaultblock
377 def abort(self, msg_ids = None, targets=None, block=None):
385 def abort(self, msg_ids = None, targets=None, block=None):
378 """abort the Queues of target(s)"""
386 """abort the Queues of target(s)"""
379 targets = self._build_targets(targets)[0]
387 targets = self._build_targets(targets)[0]
380 print targets
381 if isinstance(msg_ids, basestring):
388 if isinstance(msg_ids, basestring):
382 msg_ids = [msg_ids]
389 msg_ids = [msg_ids]
383 content = dict(msg_ids=msg_ids)
390 content = dict(msg_ids=msg_ids)
384 for t in targets:
391 for t in targets:
385 self.session.send(self.control_socket, 'abort_request',
392 self.session.send(self.control_socket, 'abort_request',
386 content=content, ident=t)
393 content=content, ident=t)
387 error = False
394 error = False
388 if self.block:
395 if self.block:
389 for i in range(len(targets)):
396 for i in range(len(targets)):
390 idents,msg = self.session.recv(self.control_socket,0)
397 idents,msg = self.session.recv(self.control_socket,0)
391 if self.debug:
398 if self.debug:
392 pprint(msg)
399 pprint(msg)
393 if msg['content']['status'] != 'ok':
400 if msg['content']['status'] != 'ok':
394 error = msg['content']
401 error = msg['content']
395 if error:
402 if error:
396 return error
403 return error
397
404
398 @spinfirst
405 @spinfirst
399 @defaultblock
406 @defaultblock
400 def kill(self, targets=None, block=None):
407 def kill(self, targets=None, block=None):
401 """Terminates one or more engine processes."""
408 """Terminates one or more engine processes."""
402 targets = self._build_targets(targets)[0]
409 targets = self._build_targets(targets)[0]
403 print targets
404 for t in targets:
410 for t in targets:
405 self.session.send(self.control_socket, 'kill_request', content={},ident=t)
411 self.session.send(self.control_socket, 'kill_request', content={},ident=t)
406 error = False
412 error = False
407 if self.block:
413 if self.block:
408 for i in range(len(targets)):
414 for i in range(len(targets)):
409 idents,msg = self.session.recv(self.control_socket,0)
415 idents,msg = self.session.recv(self.control_socket,0)
410 if self.debug:
416 if self.debug:
411 pprint(msg)
417 pprint(msg)
412 if msg['content']['status'] != 'ok':
418 if msg['content']['status'] != 'ok':
413 error = msg['content']
419 error = msg['content']
414 if error:
420 if error:
415 return error
421 return error
416
422
417 @defaultblock
423 @defaultblock
418 def execute(self, code, targets='all', block=None):
424 def execute(self, code, targets='all', block=None):
419 """executes `code` on `targets` in blocking or nonblocking manner.
425 """executes `code` on `targets` in blocking or nonblocking manner.
420
426
421 Parameters
427 Parameters
422 ----------
428 ----------
423 code : str
429 code : str
424 the code string to be executed
430 the code string to be executed
425 targets : int/str/list of ints/strs
431 targets : int/str/list of ints/strs
426 the engines on which to execute
432 the engines on which to execute
427 default : all
433 default : all
428 block : bool
434 block : bool
429 whether or not to wait until done
435 whether or not to wait until done
430 """
436 """
431 # block = self.block if block is None else block
437 # block = self.block if block is None else block
432 # saveblock = self.block
438 # saveblock = self.block
433 # self.block = block
439 # self.block = block
434 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
440 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
435 # self.block = saveblock
441 # self.block = saveblock
436 return result
442 return result
437
443
438 def run(self, code, block=None):
444 def run(self, code, block=None):
439 """runs `code` on an engine.
445 """runs `code` on an engine.
440
446
441 Calls to this are load-balanced.
447 Calls to this are load-balanced.
442
448
443 Parameters
449 Parameters
444 ----------
450 ----------
445 code : str
451 code : str
446 the code string to be executed
452 the code string to be executed
447 block : bool
453 block : bool
448 whether or not to wait until done
454 whether or not to wait until done
449
455
450 """
456 """
451 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
457 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
452 return result
458 return result
453
459
454 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
460 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
455 after=None, follow=None):
461 after=None, follow=None):
456 """the underlying method for applying functions in a load balanced
462 """the underlying method for applying functions in a load balanced
457 manner."""
463 manner."""
458 block = block if block is not None else self.block
464 block = block if block is not None else self.block
465 if isinstance(after, Dependency):
466 after = after.as_dict()
467 elif after is None:
468 after = []
469 if isinstance(follow, Dependency):
470 follow = follow.as_dict()
471 elif follow is None:
472 follow = []
473 subheader = dict(after=after, follow=follow)
459
474
460 bufs = ss.pack_apply_message(f,args,kwargs)
475 bufs = ss.pack_apply_message(f,args,kwargs)
461 content = dict(bound=bound)
476 content = dict(bound=bound)
462 msg = self.session.send(self.task_socket, "apply_request",
477 msg = self.session.send(self.task_socket, "apply_request",
463 content=content, buffers=bufs)
478 content=content, buffers=bufs, subheader=subheader)
464 msg_id = msg['msg_id']
479 msg_id = msg['msg_id']
465 self.outstanding.add(msg_id)
480 self.outstanding.add(msg_id)
466 self.history.append(msg_id)
481 self.history.append(msg_id)
467 if block:
482 if block:
468 self.barrier(msg_id)
483 self.barrier(msg_id)
469 return self.results[msg_id]
484 return self.results[msg_id]
470 else:
485 else:
471 return msg_id
486 return msg_id
472
487
473 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
488 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
474 after=None, follow=None):
489 after=None, follow=None):
475 """Then underlying method for applying functions to specific engines."""
490 """Then underlying method for applying functions to specific engines."""
476
491
477 block = block if block is not None else self.block
492 block = block if block is not None else self.block
478
493
479 queues,targets = self._build_targets(targets)
494 queues,targets = self._build_targets(targets)
480 print queues
481 bufs = ss.pack_apply_message(f,args,kwargs)
495 bufs = ss.pack_apply_message(f,args,kwargs)
482 if isinstance(after, Dependency):
496 if isinstance(after, Dependency):
483 after = after.as_dict()
497 after = after.as_dict()
484 elif after is None:
498 elif after is None:
485 after = []
499 after = []
486 if isinstance(follow, Dependency):
500 if isinstance(follow, Dependency):
487 follow = follow.as_dict()
501 follow = follow.as_dict()
488 elif follow is None:
502 elif follow is None:
489 follow = []
503 follow = []
490 subheader = dict(after=after, follow=follow)
504 subheader = dict(after=after, follow=follow)
491 content = dict(bound=bound)
505 content = dict(bound=bound)
492 msg_ids = []
506 msg_ids = []
493 for queue in queues:
507 for queue in queues:
494 msg = self.session.send(self.queue_socket, "apply_request",
508 msg = self.session.send(self.queue_socket, "apply_request",
495 content=content, buffers=bufs,ident=queue, subheader=subheader)
509 content=content, buffers=bufs,ident=queue, subheader=subheader)
496 msg_id = msg['msg_id']
510 msg_id = msg['msg_id']
497 self.outstanding.add(msg_id)
511 self.outstanding.add(msg_id)
498 self.history.append(msg_id)
512 self.history.append(msg_id)
499 msg_ids.append(msg_id)
513 msg_ids.append(msg_id)
500 if block:
514 if block:
501 self.barrier(msg_ids)
515 self.barrier(msg_ids)
502 else:
516 else:
503 if len(msg_ids) == 1:
517 if len(msg_ids) == 1:
504 return msg_ids[0]
518 return msg_ids[0]
505 else:
519 else:
506 return msg_ids
520 return msg_ids
507 if len(msg_ids) == 1:
521 if len(msg_ids) == 1:
508 return self.results[msg_ids[0]]
522 return self.results[msg_ids[0]]
509 else:
523 else:
510 result = {}
524 result = {}
511 for target,mid in zip(targets, msg_ids):
525 for target,mid in zip(targets, msg_ids):
512 result[target] = self.results[mid]
526 result[target] = self.results[mid]
513 return result
527 return result
514
528
515 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
529 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
516 after=None, follow=None):
530 after=None, follow=None):
517 """calls f(*args, **kwargs) on a remote engine(s), returning the result.
531 """calls f(*args, **kwargs) on a remote engine(s), returning the result.
518
532
519 if self.block is False:
533 if self.block is False:
520 returns msg_id or list of msg_ids
534 returns msg_id or list of msg_ids
521 else:
535 else:
522 returns actual result of f(*args, **kwargs)
536 returns actual result of f(*args, **kwargs)
523 """
537 """
524 # enforce types of f,args,kwargs
538 # enforce types of f,args,kwargs
525 args = args if args is not None else []
539 args = args if args is not None else []
526 kwargs = kwargs if kwargs is not None else {}
540 kwargs = kwargs if kwargs is not None else {}
527 if not callable(f):
541 if not callable(f):
528 raise TypeError("f must be callable, not %s"%type(f))
542 raise TypeError("f must be callable, not %s"%type(f))
529 if not isinstance(args, (tuple, list)):
543 if not isinstance(args, (tuple, list)):
530 raise TypeError("args must be tuple or list, not %s"%type(args))
544 raise TypeError("args must be tuple or list, not %s"%type(args))
531 if not isinstance(kwargs, dict):
545 if not isinstance(kwargs, dict):
532 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
546 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
533
547
534 options = dict(bound=bound, block=block, after=after, follow=follow)
548 options = dict(bound=bound, block=block, after=after, follow=follow)
535
549
536 if targets is None:
550 if targets is None:
537 return self._apply_balanced(f, args, kwargs, **options)
551 return self._apply_balanced(f, args, kwargs, **options)
538 else:
552 else:
539 return self._apply_direct(f, args, kwargs, targets=targets, **options)
553 return self._apply_direct(f, args, kwargs, targets=targets, **options)
540
554
541 def push(self, ns, targets=None, block=None):
555 def push(self, ns, targets=None, block=None):
542 """push the contents of `ns` into the namespace on `target`"""
556 """push the contents of `ns` into the namespace on `target`"""
543 if not isinstance(ns, dict):
557 if not isinstance(ns, dict):
544 raise TypeError("Must be a dict, not %s"%type(ns))
558 raise TypeError("Must be a dict, not %s"%type(ns))
545 result = self.apply(_push, (ns,), targets=targets, block=block,bound=True)
559 result = self.apply(_push, (ns,), targets=targets, block=block,bound=True)
546 return result
560 return result
547
561
548 @spinfirst
562 @spinfirst
549 def pull(self, keys, targets=None, block=True):
563 def pull(self, keys, targets=None, block=True):
550 """pull objects from `target`'s namespace by `keys`"""
564 """pull objects from `target`'s namespace by `keys`"""
551
565
552 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
566 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
553 return result
567 return result
554
568
555 def barrier(self, msg_ids=None, timeout=-1):
569 def barrier(self, msg_ids=None, timeout=-1):
556 """waits on one or more `msg_ids`, for up to `timeout` seconds.
570 """waits on one or more `msg_ids`, for up to `timeout` seconds.
557
571
558 Parameters
572 Parameters
559 ----------
573 ----------
560 msg_ids : int, str, or list of ints and/or strs
574 msg_ids : int, str, or list of ints and/or strs
561 ints are indices to self.history
575 ints are indices to self.history
562 strs are msg_ids
576 strs are msg_ids
563 default: wait on all outstanding messages
577 default: wait on all outstanding messages
564 timeout : float
578 timeout : float
565 a time in seconds, after which to give up.
579 a time in seconds, after which to give up.
566 default is -1, which means no timeout
580 default is -1, which means no timeout
567
581
568 Returns
582 Returns
569 -------
583 -------
570 True : when all msg_ids are done
584 True : when all msg_ids are done
571 False : timeout reached, msg_ids still outstanding
585 False : timeout reached, msg_ids still outstanding
572 """
586 """
573 tic = time.time()
587 tic = time.time()
574 if msg_ids is None:
588 if msg_ids is None:
575 theids = self.outstanding
589 theids = self.outstanding
576 else:
590 else:
577 if isinstance(msg_ids, (int, str)):
591 if isinstance(msg_ids, (int, str)):
578 msg_ids = [msg_ids]
592 msg_ids = [msg_ids]
579 theids = set()
593 theids = set()
580 for msg_id in msg_ids:
594 for msg_id in msg_ids:
581 if isinstance(msg_id, int):
595 if isinstance(msg_id, int):
582 msg_id = self.history[msg_id]
596 msg_id = self.history[msg_id]
583 theids.add(msg_id)
597 theids.add(msg_id)
584 self.spin()
598 self.spin()
585 while theids.intersection(self.outstanding):
599 while theids.intersection(self.outstanding):
586 if timeout >= 0 and ( time.time()-tic ) > timeout:
600 if timeout >= 0 and ( time.time()-tic ) > timeout:
587 break
601 break
588 time.sleep(1e-3)
602 time.sleep(1e-3)
589 self.spin()
603 self.spin()
590 return len(theids.intersection(self.outstanding)) == 0
604 return len(theids.intersection(self.outstanding)) == 0
591
605
592 @spinfirst
606 @spinfirst
593 def get_results(self, msg_ids,status_only=False):
607 def get_results(self, msg_ids,status_only=False):
594 """returns the result of the execute or task request with `msg_id`"""
608 """returns the result of the execute or task request with `msg_id`"""
595 if not isinstance(msg_ids, (list,tuple)):
609 if not isinstance(msg_ids, (list,tuple)):
596 msg_ids = [msg_ids]
610 msg_ids = [msg_ids]
597 theids = []
611 theids = []
598 for msg_id in msg_ids:
612 for msg_id in msg_ids:
599 if isinstance(msg_id, int):
613 if isinstance(msg_id, int):
600 msg_id = self.history[msg_id]
614 msg_id = self.history[msg_id]
601 theids.append(msg_id)
615 theids.append(msg_id)
602
616
603 content = dict(msg_ids=theids, status_only=status_only)
617 content = dict(msg_ids=theids, status_only=status_only)
604 msg = self.session.send(self.query_socket, "result_request", content=content)
618 msg = self.session.send(self.query_socket, "result_request", content=content)
605 zmq.select([self.query_socket], [], [])
619 zmq.select([self.query_socket], [], [])
606 idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK)
620 idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK)
607 if self.debug:
621 if self.debug:
608 pprint(msg)
622 pprint(msg)
609
623
610 # while True:
624 # while True:
611 # try:
625 # try:
612 # except zmq.ZMQError:
626 # except zmq.ZMQError:
613 # time.sleep(1e-3)
627 # time.sleep(1e-3)
614 # continue
628 # continue
615 # else:
629 # else:
616 # break
630 # break
617 return msg['content']
631 return msg['content']
618
632
619 class AsynClient(Client):
633 class AsynClient(Client):
620 """An Asynchronous client, using the Tornado Event Loop"""
634 """An Asynchronous client, using the Tornado Event Loop"""
621 io_loop = None
635 io_loop = None
622 queue_stream = None
636 queue_stream = None
623 notifier_stream = None
637 notifier_stream = None
624
638
625 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
639 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
626 Client.__init__(self, addr, context, username, debug)
640 Client.__init__(self, addr, context, username, debug)
627 if io_loop is None:
641 if io_loop is None:
628 io_loop = ioloop.IOLoop.instance()
642 io_loop = ioloop.IOLoop.instance()
629 self.io_loop = io_loop
643 self.io_loop = io_loop
630
644
631 self.queue_stream = zmqstream.ZMQStream(self.queue_socket, io_loop)
645 self.queue_stream = zmqstream.ZMQStream(self.queue_socket, io_loop)
632 self.control_stream = zmqstream.ZMQStream(self.control_socket, io_loop)
646 self.control_stream = zmqstream.ZMQStream(self.control_socket, io_loop)
633 self.task_stream = zmqstream.ZMQStream(self.task_socket, io_loop)
647 self.task_stream = zmqstream.ZMQStream(self.task_socket, io_loop)
634 self.notifier_stream = zmqstream.ZMQStream(self.notification_socket, io_loop)
648 self.notifier_stream = zmqstream.ZMQStream(self.notification_socket, io_loop)
635
649
636 def spin(self):
650 def spin(self):
637 for stream in (self.queue_stream, self.notifier_stream,
651 for stream in (self.queue_stream, self.notifier_stream,
638 self.task_stream, self.control_stream):
652 self.task_stream, self.control_stream):
639 stream.flush()
653 stream.flush()
640 No newline at end of file
654
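For reviewers skimming the client diff above, a hedged sketch of how the resulting Client API fits together; the connection address and module name are assumptions, and the dict-shaped return only applies when more than one engine is targeted:

from client import Client

def double(x):
    return 2 * x

c = Client('tcp://127.0.0.1:10101')            # placeholder registration address
print c.ids                                    # the property spins first, so this reflects current registrations
# direct, blocking apply on all engines; a dict keyed by engine id when several are targeted
print c.apply(double, (21,), targets='all', block=True)
# namespace round-trip on engine 0
c.push(dict(a=5), targets=0, block=True)
print c.pull('a', targets=0, block=True)
print c.queue_status()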
@@ -1,921 +1,919 @@
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
3
4 """The IPython Controller with 0MQ
2 """The IPython Controller with 0MQ
5 This is the master object that handles connections from engines and clients.
3 This is the master object that handles connections from engines and clients.
6 """
4 """
7 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
6 # Copyright (C) 2010 The IPython Development Team
9 #
7 #
10 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
13
11
14 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
15 # Imports
13 # Imports
16 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
17 from datetime import datetime
15 from datetime import datetime
18 import logging
16 import logging
19
17
20 import zmq
18 import zmq
21 from zmq.eventloop import zmqstream, ioloop
19 from zmq.eventloop import zmqstream, ioloop
22 import uuid
20 import uuid
23
21
24 # internal:
22 # internal:
25 from IPython.zmq.log import logger # a Logger object
23 from IPython.zmq.log import logger # a Logger object
26 from IPython.zmq.entry_point import bind_port
24 from IPython.zmq.entry_point import bind_port
27
25
28 from streamsession import Message, wrap_exception
26 from streamsession import Message, wrap_exception
29 from entry_point import (make_argument_parser, select_random_ports, split_ports,
27 from entry_point import (make_argument_parser, select_random_ports, split_ports,
30 connect_logger)
28 connect_logger)
31 # from messages import json # use the same import switches
29 # from messages import json # use the same import switches
32
30
33 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
34 # Code
32 # Code
35 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
36
34
37 class ReverseDict(dict):
35 class ReverseDict(dict):
38 """simple double-keyed subset of dict methods."""
36 """simple double-keyed subset of dict methods."""
39
37
40 def __init__(self, *args, **kwargs):
38 def __init__(self, *args, **kwargs):
41 dict.__init__(self, *args, **kwargs)
39 dict.__init__(self, *args, **kwargs)
42 self.reverse = dict()
40 self.reverse = dict()
43 for key, value in self.iteritems():
41 for key, value in self.iteritems():
44 self.reverse[value] = key
42 self.reverse[value] = key
45
43
46 def __getitem__(self, key):
44 def __getitem__(self, key):
47 try:
45 try:
48 return dict.__getitem__(self, key)
46 return dict.__getitem__(self, key)
49 except KeyError:
47 except KeyError:
50 return self.reverse[key]
48 return self.reverse[key]
51
49
52 def __setitem__(self, key, value):
50 def __setitem__(self, key, value):
53 if key in self.reverse:
51 if key in self.reverse:
54 raise KeyError("Can't have key %r on both sides!"%key)
52 raise KeyError("Can't have key %r on both sides!"%key)
55 dict.__setitem__(self, key, value)
53 dict.__setitem__(self, key, value)
56 self.reverse[value] = key
54 self.reverse[value] = key
57
55
58 def pop(self, key):
56 def pop(self, key):
59 value = dict.pop(self, key)
57 value = dict.pop(self, key)
60 self.reverse.pop(value)
58 self.reverse.pop(value)
61 return value
59 return value
62
60
63
61
64 class EngineConnector(object):
62 class EngineConnector(object):
65 """A simple object for accessing the various zmq connections of an object.
63 """A simple object for accessing the various zmq connections of an object.
66 Attributes are:
64 Attributes are:
67 id (int): engine ID
65 id (int): engine ID
68 uuid (str): uuid (unused?)
66 uuid (str): uuid (unused?)
69 queue (str): identity of queue's XREQ socket
67 queue (str): identity of queue's XREQ socket
70 registration (str): identity of registration XREQ socket
68 registration (str): identity of registration XREQ socket
71 heartbeat (str): identity of heartbeat XREQ socket
69 heartbeat (str): identity of heartbeat XREQ socket
72 """
70 """
73 id=0
71 id=0
74 queue=None
72 queue=None
75 control=None
73 control=None
76 registration=None
74 registration=None
77 heartbeat=None
75 heartbeat=None
78 pending=None
76 pending=None
79
77
80 def __init__(self, id, queue, registration, control, heartbeat=None):
78 def __init__(self, id, queue, registration, control, heartbeat=None):
81 logger.info("engine::Engine Connected: %i"%id)
79 logger.info("engine::Engine Connected: %i"%id)
82 self.id = id
80 self.id = id
83 self.queue = queue
81 self.queue = queue
84 self.registration = registration
82 self.registration = registration
85 self.control = control
83 self.control = control
86 self.heartbeat = heartbeat
84 self.heartbeat = heartbeat
87
85
88 class Controller(object):
86 class Controller(object):
89 """The IPython Controller with 0MQ connections
87 """The IPython Controller with 0MQ connections
90
88
91 Parameters
89 Parameters
92 ==========
90 ==========
93 loop: zmq IOLoop instance
91 loop: zmq IOLoop instance
94 session: StreamSession object
92 session: StreamSession object
95 <removed> context: zmq context for creating new connections (?)
93 <removed> context: zmq context for creating new connections (?)
96 registrar: ZMQStream for engine registration requests (XREP)
94 registrar: ZMQStream for engine registration requests (XREP)
97 clientele: ZMQStream for client connections (XREP)
95 clientele: ZMQStream for client connections (XREP)
98 not used for jobs, only query/control commands
96 not used for jobs, only query/control commands
99 queue: ZMQStream for monitoring the command queue (SUB)
97 queue: ZMQStream for monitoring the command queue (SUB)
100 heartbeat: HeartMonitor object checking the pulse of the engines
98 heartbeat: HeartMonitor object checking the pulse of the engines
101 db_stream: connection to db for out of memory logging of commands
99 db_stream: connection to db for out of memory logging of commands
102 NotImplemented
100 NotImplemented
103 queue_addr: zmq connection address of the XREP socket for the queue
101 queue_addr: zmq connection address of the XREP socket for the queue
104 hb_addr: zmq connection address of the PUB socket for heartbeats
102 hb_addr: zmq connection address of the PUB socket for heartbeats
105 task_addr: zmq connection address of the XREQ socket for task queue
103 task_addr: zmq connection address of the XREQ socket for task queue
106 """
104 """
107 # internal data structures:
105 # internal data structures:
108 ids=None # engine IDs
106 ids=None # engine IDs
109 keytable=None
107 keytable=None
110 engines=None
108 engines=None
111 clients=None
109 clients=None
112 hearts=None
110 hearts=None
113 pending=None
111 pending=None
114 results=None
112 results=None
115 tasks=None
113 tasks=None
116 completed=None
114 completed=None
117 mia=None
115 mia=None
118 incoming_registrations=None
116 incoming_registrations=None
119 registration_timeout=None
117 registration_timeout=None
120
118
121 #objects from constructor:
119 #objects from constructor:
122 loop=None
120 loop=None
123 registrar=None
121 registrar=None
124 clientele=None
122 clientele=None
125 queue=None
123 queue=None
126 heartbeat=None
124 heartbeat=None
127 notifier=None
125 notifier=None
128 db=None
126 db=None
129 client_addr=None
127 client_addr=None
130 engine_addrs=None
128 engine_addrs=None
131
129
132
130
133 def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs):
131 def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs):
134 """
132 """
135 # universal:
133 # universal:
136 loop: IOLoop for creating future connections
134 loop: IOLoop for creating future connections
137 session: streamsession for sending serialized data
135 session: streamsession for sending serialized data
138 # engine:
136 # engine:
139 queue: ZMQStream for monitoring queue messages
137 queue: ZMQStream for monitoring queue messages
140 registrar: ZMQStream for engine registration
138 registrar: ZMQStream for engine registration
141 heartbeat: HeartMonitor object for tracking engines
139 heartbeat: HeartMonitor object for tracking engines
142 # client:
140 # client:
143 clientele: ZMQStream for client connections
141 clientele: ZMQStream for client connections
144 # extra:
142 # extra:
145 db: ZMQStream for db connection (NotImplemented)
143 db: ZMQStream for db connection (NotImplemented)
146 engine_addrs: zmq address/protocol dict for engine connections
144 engine_addrs: zmq address/protocol dict for engine connections
147 client_addrs: zmq address/protocol dict for client connections
145 client_addrs: zmq address/protocol dict for client connections
148 """
146 """
149 self.ids = set()
147 self.ids = set()
150 self.keytable={}
148 self.keytable={}
151 self.incoming_registrations={}
149 self.incoming_registrations={}
152 self.engines = {}
150 self.engines = {}
153 self.by_ident = {}
151 self.by_ident = {}
154 self.clients = {}
152 self.clients = {}
155 self.hearts = {}
153 self.hearts = {}
156 self.mia = set()
154 self.mia = set()
157
155
158 # self.sockets = {}
156 # self.sockets = {}
159 self.loop = loop
157 self.loop = loop
160 self.session = session
158 self.session = session
161 self.registrar = registrar
159 self.registrar = registrar
162 self.clientele = clientele
160 self.clientele = clientele
163 self.queue = queue
161 self.queue = queue
164 self.heartbeat = heartbeat
162 self.heartbeat = heartbeat
165 self.notifier = notifier
163 self.notifier = notifier
166 self.db = db
164 self.db = db
167
165
168 self.client_addrs = client_addrs
166 self.client_addrs = client_addrs
169 assert isinstance(client_addrs['queue'], str)
167 assert isinstance(client_addrs['queue'], str)
170 # self.hb_addrs = hb_addrs
168 # self.hb_addrs = hb_addrs
171 self.engine_addrs = engine_addrs
169 self.engine_addrs = engine_addrs
172 assert isinstance(engine_addrs['queue'], str)
170 assert isinstance(engine_addrs['queue'], str)
173 assert len(engine_addrs['heartbeat']) == 2
171 assert len(engine_addrs['heartbeat']) == 2
174
172
175
173
176 # register our callbacks
174 # register our callbacks
177 self.registrar.on_recv(self.dispatch_register_request)
175 self.registrar.on_recv(self.dispatch_register_request)
178 self.clientele.on_recv(self.dispatch_client_msg)
176 self.clientele.on_recv(self.dispatch_client_msg)
179 self.queue.on_recv(self.dispatch_queue_traffic)
177 self.queue.on_recv(self.dispatch_queue_traffic)
180
178
181 if heartbeat is not None:
179 if heartbeat is not None:
182 heartbeat.add_heart_failure_handler(self.handle_heart_failure)
180 heartbeat.add_heart_failure_handler(self.handle_heart_failure)
183 heartbeat.add_new_heart_handler(self.handle_new_heart)
181 heartbeat.add_new_heart_handler(self.handle_new_heart)
184
182
185 if self.db is not None:
183 if self.db is not None:
186 self.db.on_recv(self.dispatch_db)
184 self.db.on_recv(self.dispatch_db)
187
185
188 self.client_handlers = {'queue_request': self.queue_status,
186 self.client_handlers = {'queue_request': self.queue_status,
189 'result_request': self.get_results,
187 'result_request': self.get_results,
190 'purge_request': self.purge_results,
188 'purge_request': self.purge_results,
191 'resubmit_request': self.resubmit_task,
189 'resubmit_request': self.resubmit_task,
192 }
190 }
193
191
194 self.registrar_handlers = {'registration_request' : self.register_engine,
192 self.registrar_handlers = {'registration_request' : self.register_engine,
195 'unregistration_request' : self.unregister_engine,
193 'unregistration_request' : self.unregister_engine,
196 'connection_request': self.connection_request,
194 'connection_request': self.connection_request,
197
195
198 }
196 }
199 #
197 #
200 # this is the stuff that will move to DB:
198 # this is the stuff that will move to DB:
201 self.results = {} # completed results
199 self.results = {} # completed results
202 self.pending = {} # pending messages, keyed by msg_id
200 self.pending = {} # pending messages, keyed by msg_id
203 self.queues = {} # pending msg_ids keyed by engine_id
201 self.queues = {} # pending msg_ids keyed by engine_id
204 self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id
202 self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id
205 self.completed = {} # completed msg_ids keyed by engine_id
203 self.completed = {} # completed msg_ids keyed by engine_id
206 self.registration_timeout = max(5000, 2*self.heartbeat.period)
204 self.registration_timeout = max(5000, 2*self.heartbeat.period)
207
205
208 logger.info("controller::created controller")
206 logger.info("controller::created controller")
209
207
210 def _new_id(self):
208 def _new_id(self):
211 """gemerate a new ID"""
209 """gemerate a new ID"""
212 newid = 0
210 newid = 0
213 incoming = [id[0] for id in self.incoming_registrations.itervalues()]
211 incoming = [id[0] for id in self.incoming_registrations.itervalues()]
214 # print newid, self.ids, self.incoming_registrations
212 # print newid, self.ids, self.incoming_registrations
215 while newid in self.ids or newid in incoming:
213 while newid in self.ids or newid in incoming:
216 newid += 1
214 newid += 1
217 return newid
215 return newid
218
216
219
220 #-----------------------------------------------------------------------------
217 #-----------------------------------------------------------------------------
221 # message validation
218 # message validation
222 #-----------------------------------------------------------------------------
219 #-----------------------------------------------------------------------------
220
223 def _validate_targets(self, targets):
221 def _validate_targets(self, targets):
224 """turn any valid targets argument into a list of integer ids"""
222 """turn any valid targets argument into a list of integer ids"""
225 if targets is None:
223 if targets is None:
226 # default to all
224 # default to all
227 targets = self.ids
225 targets = self.ids
228
226
229 if isinstance(targets, (int,str,unicode)):
227 if isinstance(targets, (int,str,unicode)):
230 # only one target specified
228 # only one target specified
231 targets = [targets]
229 targets = [targets]
232 _targets = []
230 _targets = []
233 for t in targets:
231 for t in targets:
234 # map raw identities to ids
232 # map raw identities to ids
235 if isinstance(t, (str,unicode)):
233 if isinstance(t, (str,unicode)):
236 t = self.by_ident.get(t, t)
234 t = self.by_ident.get(t, t)
237 _targets.append(t)
235 _targets.append(t)
238 targets = _targets
236 targets = _targets
239 bad_targets = [ t for t in targets if t not in self.ids ]
237 bad_targets = [ t for t in targets if t not in self.ids ]
240 if bad_targets:
238 if bad_targets:
241 raise IndexError("No Such Engine: %r"%bad_targets)
239 raise IndexError("No Such Engine: %r"%bad_targets)
242 if not targets:
240 if not targets:
243 raise IndexError("No Engines Registered")
241 raise IndexError("No Engines Registered")
244 return targets
242 return targets
245
243
246 def _validate_client_msg(self, msg):
244 def _validate_client_msg(self, msg):
247 """validates and unpacks headers of a message. Returns False if invalid,
245 """validates and unpacks headers of a message. Returns False if invalid,
248 (client_id, msg) otherwise"""
246 (client_id, msg) otherwise"""
249 client_id = msg[0]
247 client_id = msg[0]
250 try:
248 try:
251 msg = self.session.unpack_message(msg[1:], content=True)
249 msg = self.session.unpack_message(msg[1:], content=True)
252 except:
250 except:
253 logger.error("client::Invalid Message %s"%msg)
251 logger.error("client::Invalid Message %s"%msg)
254 return False
252 return False
255
253
256 msg_type = msg.get('msg_type', None)
254 msg_type = msg.get('msg_type', None)
257 if msg_type is None:
255 if msg_type is None:
258 return False
256 return False
259 header = msg.get('header')
257 header = msg.get('header')
260 # session doesn't handle split content for now:
258 # session doesn't handle split content for now:
261 return client_id, msg
259 return client_id, msg
262
260
263
261
264 #-----------------------------------------------------------------------------
262 #-----------------------------------------------------------------------------
265 # dispatch methods (1 per socket)
263 # dispatch methods (1 per stream)
266 #-----------------------------------------------------------------------------
264 #-----------------------------------------------------------------------------
267
265
268 def dispatch_register_request(self, msg):
266 def dispatch_register_request(self, msg):
269 """"""
267 """"""
270 logger.debug("registration::dispatch_register_request(%s)"%msg)
268 logger.debug("registration::dispatch_register_request(%s)"%msg)
271 idents,msg = self.session.feed_identities(msg)
269 idents,msg = self.session.feed_identities(msg)
272 print idents,msg, len(msg)
270 print idents,msg, len(msg)
273 try:
271 try:
274 msg = self.session.unpack_message(msg,content=True)
272 msg = self.session.unpack_message(msg,content=True)
275 except Exception, e:
273 except Exception, e:
276 logger.error("registration::got bad registration message: %s"%msg)
274 logger.error("registration::got bad registration message: %s"%msg)
277 raise e
275 raise e
278 return
276 return
279
277
280 msg_type = msg['msg_type']
278 msg_type = msg['msg_type']
281 content = msg['content']
279 content = msg['content']
282
280
283 handler = self.registrar_handlers.get(msg_type, None)
281 handler = self.registrar_handlers.get(msg_type, None)
284 if handler is None:
282 if handler is None:
285 logger.error("registration::got bad registration message: %s"%msg)
283 logger.error("registration::got bad registration message: %s"%msg)
286 else:
284 else:
287 handler(idents, msg)
285 handler(idents, msg)
288
286
289 def dispatch_queue_traffic(self, msg):
287 def dispatch_queue_traffic(self, msg):
290 """all ME and Task queue messages come through here"""
288 """all ME and Task queue messages come through here"""
291 logger.debug("queue traffic: %s"%msg[:2])
289 logger.debug("queue traffic: %s"%msg[:2])
292 switch = msg[0]
290 switch = msg[0]
293 idents, msg = self.session.feed_identities(msg[1:])
291 idents, msg = self.session.feed_identities(msg[1:])
294 if switch == 'in':
292 if switch == 'in':
295 self.save_queue_request(idents, msg)
293 self.save_queue_request(idents, msg)
296 elif switch == 'out':
294 elif switch == 'out':
297 self.save_queue_result(idents, msg)
295 self.save_queue_result(idents, msg)
298 elif switch == 'intask':
296 elif switch == 'intask':
299 self.save_task_request(idents, msg)
297 self.save_task_request(idents, msg)
300 elif switch == 'outtask':
298 elif switch == 'outtask':
301 self.save_task_result(idents, msg)
299 self.save_task_result(idents, msg)
302 elif switch == 'tracktask':
300 elif switch == 'tracktask':
303 self.save_task_destination(idents, msg)
301 self.save_task_destination(idents, msg)
304 elif switch in ('incontrol', 'outcontrol'):
302 elif switch in ('incontrol', 'outcontrol'):
305 pass
303 pass
306 else:
304 else:
307 logger.error("Invalid message topic: %s"%switch)
305 logger.error("Invalid message topic: %s"%switch)
308
306
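# -- Editor's sketch (illustration, not part of the original diff) -----------
# The first frame of every monitored-queue message is a topic string; the
# if/elif chain above amounts to the following routing table (the handler
# names are the real methods in this class, the dict itself is illustrative):
#
#     {'in':         save_queue_request,
#      'out':        save_queue_result,
#      'intask':     save_task_request,
#      'outtask':    save_task_result,
#      'tracktask':  save_task_destination,
#      'incontrol':  None,   # control traffic is currently ignored
#      'outcontrol': None}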
309
307
310 def dispatch_client_msg(self, msg):
308 def dispatch_client_msg(self, msg):
311 """Route messages from clients"""
309 """Route messages from clients"""
312 idents, msg = self.session.feed_identities(msg)
310 idents, msg = self.session.feed_identities(msg)
313 client_id = idents[0]
311 client_id = idents[0]
314 try:
312 try:
315 msg = self.session.unpack_message(msg, content=True)
313 msg = self.session.unpack_message(msg, content=True)
316 except:
314 except:
317 content = wrap_exception()
315 content = wrap_exception()
318 logger.error("Bad Client Message: %s"%msg)
316 logger.error("Bad Client Message: %s"%msg)
319 self.session.send(self.clientele, "controller_error", ident=client_id,
317 self.session.send(self.clientele, "controller_error", ident=client_id,
320 content=content)
318 content=content)
321 return
319 return
322
320
323 # print client_id, header, parent, content
321 # print client_id, header, parent, content
324 #switch on message type:
322 #switch on message type:
325 msg_type = msg['msg_type']
323 msg_type = msg['msg_type']
326 logger.info("client::client %s requested %s"%(client_id, msg_type))
324 logger.info("client::client %s requested %s"%(client_id, msg_type))
327 handler = self.client_handlers.get(msg_type, None)
325 handler = self.client_handlers.get(msg_type, None)
328 try:
326 try:
329 assert handler is not None, "Bad Message Type: %s"%msg_type
327 assert handler is not None, "Bad Message Type: %s"%msg_type
330 except:
328 except:
331 content = wrap_exception()
329 content = wrap_exception()
332 logger.error("Bad Message Type: %s"%msg_type)
330 logger.error("Bad Message Type: %s"%msg_type)
333 self.session.send(self.clientele, "controller_error", ident=client_id,
331 self.session.send(self.clientele, "controller_error", ident=client_id,
334 content=content)
332 content=content)
335 return
333 return
336 else:
334 else:
337 handler(client_id, msg)
335 handler(client_id, msg)
338
336
339 def dispatch_db(self, msg):
337 def dispatch_db(self, msg):
340 """"""
338 """"""
341 raise NotImplementedError
339 raise NotImplementedError
342
340
343 #---------------------------------------------------------------------------
341 #---------------------------------------------------------------------------
344 # handler methods (1 per event)
342 # handler methods (1 per event)
345 #---------------------------------------------------------------------------
343 #---------------------------------------------------------------------------
346
344
347 #----------------------- Heartbeat --------------------------------------
345 #----------------------- Heartbeat --------------------------------------
348
346
349 def handle_new_heart(self, heart):
347 def handle_new_heart(self, heart):
350 """handler to attach to heartbeater.
348 """handler to attach to heartbeater.
351 Called when a new heart starts to beat.
349 Called when a new heart starts to beat.
352 Triggers completion of registration."""
350 Triggers completion of registration."""
353 logger.debug("heartbeat::handle_new_heart(%r)"%heart)
351 logger.debug("heartbeat::handle_new_heart(%r)"%heart)
354 if heart not in self.incoming_registrations:
352 if heart not in self.incoming_registrations:
355 logger.info("heartbeat::ignoring new heart: %r"%heart)
353 logger.info("heartbeat::ignoring new heart: %r"%heart)
356 else:
354 else:
357 self.finish_registration(heart)
355 self.finish_registration(heart)
358
356
359
357
360 def handle_heart_failure(self, heart):
358 def handle_heart_failure(self, heart):
361 """handler to attach to heartbeater.
359 """handler to attach to heartbeater.
362 called when a previously registered heart fails to respond to beat request.
360 called when a previously registered heart fails to respond to beat request.
363 triggers unregistration"""
361 triggers unregistration"""
364 logger.debug("heartbeat::handle_heart_failure(%r)"%heart)
362 logger.debug("heartbeat::handle_heart_failure(%r)"%heart)
365 eid = self.hearts.get(heart, None)
363 eid = self.hearts.get(heart, None)
366 if eid is None:
364 if eid is None:
367 logger.info("heartbeat::ignoring heart failure %r"%heart)
365 logger.info("heartbeat::ignoring heart failure %r"%heart)
368 else:
366 else:
369 queue = self.engines[eid].queue
367 queue = self.engines[eid].queue
370 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
368 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
371
369
372 #----------------------- MUX Queue Traffic ------------------------------
370 #----------------------- MUX Queue Traffic ------------------------------
373
371
374 def save_queue_request(self, idents, msg):
372 def save_queue_request(self, idents, msg):
375 queue_id, client_id = idents[:2]
373 queue_id, client_id = idents[:2]
376
374
377 try:
375 try:
378 msg = self.session.unpack_message(msg, content=False)
376 msg = self.session.unpack_message(msg, content=False)
379 except:
377 except:
380 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
378 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
381 return
379 return
382
380
383 eid = self.by_ident.get(queue_id, None)
381 eid = self.by_ident.get(queue_id, None)
384 if eid is None:
382 if eid is None:
385 logger.error("queue::target %r not registered"%queue_id)
383 logger.error("queue::target %r not registered"%queue_id)
386 logger.debug("queue:: valid are: %s"%(self.by_ident.keys()))
384 logger.debug("queue:: valid are: %s"%(self.by_ident.keys()))
387 return
385 return
388
386
389 header = msg['header']
387 header = msg['header']
390 msg_id = header['msg_id']
388 msg_id = header['msg_id']
391 info = dict(submit=datetime.now(),
389 info = dict(submit=datetime.now(),
392 received=None,
390 received=None,
393 engine=(eid, queue_id))
391 engine=(eid, queue_id))
394 self.pending[msg_id] = ( msg, info )
392 self.pending[msg_id] = ( msg, info )
395 self.queues[eid][0].append(msg_id)
393 self.queues[eid][0].append(msg_id)
396
394
397 def save_queue_result(self, idents, msg):
395 def save_queue_result(self, idents, msg):
398 client_id, queue_id = idents[:2]
396 client_id, queue_id = idents[:2]
399
397
400 try:
398 try:
401 msg = self.session.unpack_message(msg, content=False)
399 msg = self.session.unpack_message(msg, content=False)
402 except:
400 except:
403 logger.error("queue::engine %r sent invalid message to %r: %s"%(
401 logger.error("queue::engine %r sent invalid message to %r: %s"%(
404 queue_id,client_id, msg))
402 queue_id,client_id, msg))
405 return
403 return
406
404
407 eid = self.by_ident.get(queue_id, None)
405 eid = self.by_ident.get(queue_id, None)
408 if eid is None:
406 if eid is None:
409 logger.error("queue::unknown engine %r is sending a reply: "%queue_id)
407 logger.error("queue::unknown engine %r is sending a reply: "%queue_id)
410 logger.debug("queue:: %s"%msg[2:])
408 logger.debug("queue:: %s"%msg[2:])
411 return
409 return
412
410
413 parent = msg['parent_header']
411 parent = msg['parent_header']
414 if not parent:
412 if not parent:
415 return
413 return
416 msg_id = parent['msg_id']
414 msg_id = parent['msg_id']
417 self.results[msg_id] = msg
415 self.results[msg_id] = msg
418 if msg_id in self.pending:
416 if msg_id in self.pending:
419 self.pending.pop(msg_id)
417 self.pending.pop(msg_id)
420 self.queues[eid][0].remove(msg_id)
418 self.queues[eid][0].remove(msg_id)
421 self.completed[eid].append(msg_id)
419 self.completed[eid].append(msg_id)
422 else:
420 else:
423 logger.debug("queue:: unknown msg finished %s"%msg_id)
421 logger.debug("queue:: unknown msg finished %s"%msg_id)
424
422
425 #--------------------- Task Queue Traffic ------------------------------
423 #--------------------- Task Queue Traffic ------------------------------
426
424
427 def save_task_request(self, idents, msg):
425 def save_task_request(self, idents, msg):
428 client_id = idents[0]
426 client_id = idents[0]
429
427
430 try:
428 try:
431 msg = self.session.unpack_message(msg, content=False)
429 msg = self.session.unpack_message(msg, content=False)
432 except:
430 except:
433 logger.error("task::client %r sent invalid task message: %s"%(
431 logger.error("task::client %r sent invalid task message: %s"%(
434 client_id, msg))
432 client_id, msg))
435 return
433 return
436
434
437 header = msg['header']
435 header = msg['header']
438 msg_id = header['msg_id']
436 msg_id = header['msg_id']
439 self.mia.add(msg_id)
437 self.mia.add(msg_id)
440 self.pending[msg_id] = msg
438 self.pending[msg_id] = msg
441 if not self.tasks.has_key(client_id):
439 if not self.tasks.has_key(client_id):
442 self.tasks[client_id] = []
440 self.tasks[client_id] = []
443 self.tasks[client_id].append(msg_id)
441 self.tasks[client_id].append(msg_id)
444
442
445 def save_task_result(self, idents, msg):
443 def save_task_result(self, idents, msg):
446 client_id = idents[0]
444 client_id = idents[0]
447 try:
445 try:
448 msg = self.session.unpack_message(msg, content=False)
446 msg = self.session.unpack_message(msg, content=False)
449 except:
447 except:
450 logger.error("task::invalid task result message sent to %r: %s"%(
448 logger.error("task::invalid task result message sent to %r: %s"%(
451 client_id, msg))
449 client_id, msg))
452 return
450 return
453
451
454 parent = msg['parent_header']
452 parent = msg['parent_header']
455 if not parent:
453 if not parent:
456 # print msg
454 # print msg
457 # logger.warn("")
455 # logger.warn("")
458 return
456 return
459 msg_id = parent['msg_id']
457 msg_id = parent['msg_id']
460 self.results[msg_id] = msg
458 self.results[msg_id] = msg
461 if msg_id in self.pending:
459 if msg_id in self.pending:
462 self.pending.pop(msg_id)
460 self.pending.pop(msg_id)
463 if msg_id in self.mia:
461 if msg_id in self.mia:
464 self.mia.remove(msg_id)
462 self.mia.remove(msg_id)
465 else:
463 else:
466 logger.debug("task:: unknown task %s finished"%msg_id)
464 logger.debug("task::unknown task %s finished"%msg_id)
467
465
468 def save_task_destination(self, idents, msg):
466 def save_task_destination(self, idents, msg):
469 try:
467 try:
470 msg = self.session.unpack_message(msg, content=True)
468 msg = self.session.unpack_message(msg, content=True)
471 except:
469 except:
472 logger.error("task::invalid task tracking message")
470 logger.error("task::invalid task tracking message")
473 return
471 return
474 content = msg['content']
472 content = msg['content']
475 print content
473 print content
476 msg_id = content['msg_id']
474 msg_id = content['msg_id']
477 engine_uuid = content['engine_id']
475 engine_uuid = content['engine_id']
478 for eid,queue_id in self.keytable.iteritems():
476 for eid,queue_id in self.keytable.iteritems():
479 if queue_id == engine_uuid:
477 if queue_id == engine_uuid:
480 break
478 break
481
479
482 logger.info("task:: task %s arrived on %s"%(msg_id, eid))
480 logger.info("task::task %s arrived on %s"%(msg_id, eid))
483 if msg_id in self.mia:
481 if msg_id in self.mia:
484 self.mia.remove(msg_id)
482 self.mia.remove(msg_id)
485 else:
483 else:
486 logger.debug("task::task %s not listed as MIA?!"%(msg_id))
484 logger.debug("task::task %s not listed as MIA?!"%(msg_id))
487 self.tasks[engine_uuid].append(msg_id)
485 self.tasks[engine_uuid].append(msg_id)
488
486
489 def mia_task_request(self, idents, msg):
487 def mia_task_request(self, idents, msg):
490 client_id = idents[0]
488 client_id = idents[0]
491 content = dict(mia=self.mia,status='ok')
489 content = dict(mia=self.mia,status='ok')
492 self.session.send('mia_reply', content=content, idents=client_id)
490 self.session.send('mia_reply', content=content, idents=client_id)
493
491
494
492
495
493
496 #-------------------- Registration -----------------------------
494 #-------------------- Registration -----------------------------
497
495
498 def connection_request(self, client_id, msg):
496 def connection_request(self, client_id, msg):
499 """reply with connection addresses for clients"""
497 """reply with connection addresses for clients"""
500 logger.info("client::client %s connected"%client_id)
498 logger.info("client::client %s connected"%client_id)
501 content = dict(status='ok')
499 content = dict(status='ok')
502 content.update(self.client_addrs)
500 content.update(self.client_addrs)
503 jsonable = {}
501 jsonable = {}
504 for k,v in self.keytable.iteritems():
502 for k,v in self.keytable.iteritems():
505 jsonable[str(k)] = v
503 jsonable[str(k)] = v
506 content['engines'] = jsonable
504 content['engines'] = jsonable
507 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
505 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
508
506
509 def register_engine(self, reg, msg):
507 def register_engine(self, reg, msg):
510 """register an engine"""
508 """register an engine"""
511 content = msg['content']
509 content = msg['content']
512 try:
510 try:
513 queue = content['queue']
511 queue = content['queue']
514 except KeyError:
512 except KeyError:
515 logger.error("registration::queue not specified")
513 logger.error("registration::queue not specified")
516 return
514 return
517 heart = content.get('heartbeat', None)
515 heart = content.get('heartbeat', None)
518 """register a new engine, and create the socket(s) necessary"""
516 """register a new engine, and create the socket(s) necessary"""
519 eid = self._new_id()
517 eid = self._new_id()
520 # print (eid, queue, reg, heart)
518 # print (eid, queue, reg, heart)
521
519
522 logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
520 logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
523
521
524 content = dict(id=eid,status='ok')
522 content = dict(id=eid,status='ok')
525 content.update(self.engine_addrs)
523 content.update(self.engine_addrs)
526 # check if requesting available IDs:
524 # check if requesting available IDs:
527 if queue in self.by_ident:
525 if queue in self.by_ident:
528 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
526 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
529 elif heart in self.hearts: # need to check unique hearts?
527 elif heart in self.hearts: # need to check unique hearts?
530 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
528 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
531 else:
529 else:
532 for h, pack in self.incoming_registrations.iteritems():
530 for h, pack in self.incoming_registrations.iteritems():
533 if heart == h:
531 if heart == h:
534 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
532 content = {'status': 'error', 'reason': "heart_id %r in use"%heart}
535 break
533 break
536 elif queue == pack[1]:
534 elif queue == pack[1]:
537 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
535 content = {'status': 'error', 'reason': "queue_id %r in use"%queue}
538 break
536 break
539
537
540 msg = self.session.send(self.registrar, "registration_reply",
538 msg = self.session.send(self.registrar, "registration_reply",
541 content=content,
539 content=content,
542 ident=reg)
540 ident=reg)
543
541
544 if content['status'] == 'ok':
542 if content['status'] == 'ok':
545 if heart in self.heartbeat.hearts:
543 if heart in self.heartbeat.hearts:
546 # already beating
544 # already beating
547 self.incoming_registrations[heart] = (eid,queue,reg,None)
545 self.incoming_registrations[heart] = (eid,queue,reg,None)
548 self.finish_registration(heart)
546 self.finish_registration(heart)
549 else:
547 else:
550 purge = lambda : self._purge_stalled_registration(heart)
548 purge = lambda : self._purge_stalled_registration(heart)
551 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
549 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
552 dc.start()
550 dc.start()
553 self.incoming_registrations[heart] = (eid,queue,reg,dc)
551 self.incoming_registrations[heart] = (eid,queue,reg,dc)
554 else:
552 else:
555 logger.error("registration::registration %i failed: %s"%(eid, content['reason']))
553 logger.error("registration::registration %i failed: %s"%(eid, content['reason']))
556 return eid
554 return eid
557
555
558 def unregister_engine(self, ident, msg):
556 def unregister_engine(self, ident, msg):
559 try:
557 try:
560 eid = msg['content']['id']
558 eid = msg['content']['id']
561 except:
559 except:
562 logger.error("registration::bad engine id for unregistration: %s"%ident)
560 logger.error("registration::bad engine id for unregistration: %s"%ident)
563 return
561 return
564 logger.info("registration::unregister_engine(%s)"%eid)
562 logger.info("registration::unregister_engine(%s)"%eid)
565 content=dict(id=eid, queue=self.engines[eid].queue)
563 content=dict(id=eid, queue=self.engines[eid].queue)
566 self.ids.remove(eid)
564 self.ids.remove(eid)
567 self.keytable.pop(eid)
565 self.keytable.pop(eid)
568 ec = self.engines.pop(eid)
566 ec = self.engines.pop(eid)
569 self.hearts.pop(ec.heartbeat)
567 self.hearts.pop(ec.heartbeat)
570 self.by_ident.pop(ec.queue)
568 self.by_ident.pop(ec.queue)
571 self.completed.pop(eid)
569 self.completed.pop(eid)
572 for msg_id in self.queues.pop(eid)[0]:
570 for msg_id in self.queues.pop(eid)[0]:
573 msg = self.pending.pop(msg_id)
571 msg = self.pending.pop(msg_id)
574 ############## TODO: HANDLE IT ################
572 ############## TODO: HANDLE IT ################
575
573
576 if self.notifier:
574 if self.notifier:
577 self.session.send(self.notifier, "unregistration_notification", content=content)
575 self.session.send(self.notifier, "unregistration_notification", content=content)
578
576
579 def finish_registration(self, heart):
577 def finish_registration(self, heart):
580 try:
578 try:
581 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
579 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
582 except KeyError:
580 except KeyError:
583 logger.error("registration::tried to finish nonexistent registration")
581 logger.error("registration::tried to finish nonexistent registration")
584 return
582 return
585 logger.info("registration::finished registering engine %i:%r"%(eid,queue))
583 logger.info("registration::finished registering engine %i:%r"%(eid,queue))
586 if purge is not None:
584 if purge is not None:
587 purge.stop()
585 purge.stop()
588 control = queue
586 control = queue
589 self.ids.add(eid)
587 self.ids.add(eid)
590 self.keytable[eid] = queue
588 self.keytable[eid] = queue
591 self.engines[eid] = EngineConnector(eid, queue, reg, control, heart)
589 self.engines[eid] = EngineConnector(eid, queue, reg, control, heart)
592 self.by_ident[queue] = eid
590 self.by_ident[queue] = eid
593 self.queues[eid] = ([],[])
591 self.queues[eid] = ([],[])
594 self.completed[eid] = list()
592 self.completed[eid] = list()
595 self.hearts[heart] = eid
593 self.hearts[heart] = eid
596 content = dict(id=eid, queue=self.engines[eid].queue)
594 content = dict(id=eid, queue=self.engines[eid].queue)
597 if self.notifier:
595 if self.notifier:
598 self.session.send(self.notifier, "registration_notification", content=content)
596 self.session.send(self.notifier, "registration_notification", content=content)
599
597
600 def _purge_stalled_registration(self, heart):
598 def _purge_stalled_registration(self, heart):
601 if heart in self.incoming_registrations:
599 if heart in self.incoming_registrations:
602 eid = self.incoming_registrations.pop(heart)[0]
600 eid = self.incoming_registrations.pop(heart)[0]
603 logger.info("registration::purging stalled registration: %i"%eid)
601 logger.info("registration::purging stalled registration: %i"%eid)
604 else:
602 else:
605 pass
603 pass
606
604
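# -- Editor's sketch (illustration, not part of the original diff) -----------
# The registration handshake, as implied by register_engine and
# finish_registration above (identities and ids below are made up):
#
#   engine -> registrar:  msg_type='registration_request'
#       content = {'queue': 'queue-abc',        # engine's queue identity
#                  'heartbeat': 'heart-abc'}    # optional heart identity
#   registrar -> engine:  msg_type='registration_reply'
#       content = {'id': 0, 'status': 'ok'}     # merged with engine_addrs
#       or        {'status': 'error', 'reason': '...'}
#
# Registration only *finishes* once the engine's heart starts beating;
# otherwise the pending entry is purged after registration_timeout.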
607 #------------------- Client Requests -------------------------------
605 #------------------- Client Requests -------------------------------
608
606
609 def check_load(self, client_id, msg):
607 def check_load(self, client_id, msg):
610 content = msg['content']
608 content = msg['content']
611 try:
609 try:
612 targets = content['targets']
610 targets = content['targets']
613 targets = self._validate_targets(targets)
611 targets = self._validate_targets(targets)
614 except:
612 except:
615 content = wrap_exception()
613 content = wrap_exception()
616 self.session.send(self.clientele, "controller_error",
614 self.session.send(self.clientele, "controller_error",
617 content=content, ident=client_id)
615 content=content, ident=client_id)
618 return
616 return
619
617
620 content = dict(status='ok')
618 content = dict(status='ok')
621 # loads = {}
619 # loads = {}
622 for t in targets:
620 for t in targets:
623 content[str(t)] = len(self.queues[t])
621 content[str(t)] = len(self.queues[t])
624 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
622 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
625
623
626
624
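# -- Editor's sketch (illustration, not part of the original diff) -----------
# Shape of a load_reply, keyed by str(engine id) with the per-engine queue
# length as the value (ids and lengths below are made up):
#     {'status': 'ok', '0': 3, '1': 0}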
627 def queue_status(self, client_id, msg):
625 def queue_status(self, client_id, msg):
628 """handle queue_status request"""
626 """handle queue_status request"""
629 content = msg['content']
627 content = msg['content']
630 targets = content['targets']
628 targets = content['targets']
631 try:
629 try:
632 targets = self._validate_targets(targets)
630 targets = self._validate_targets(targets)
633 except:
631 except:
634 content = wrap_exception()
632 content = wrap_exception()
635 self.session.send(self.clientele, "controller_error",
633 self.session.send(self.clientele, "controller_error",
636 content=content, ident=client_id)
634 content=content, ident=client_id)
637 return
635 return
638 verbose = content.get('verbose', False)
636 verbose = content.get('verbose', False)
639 content = dict()
637 content = dict()
640 for t in targets:
638 for t in targets:
641 queue = self.queues[t]
639 queue = self.queues[t]
642 completed = self.completed[t]
640 completed = self.completed[t]
643 if not verbose:
641 if not verbose:
644 queue = len(queue)
642 queue = len(queue)
645 completed = len(completed)
643 completed = len(completed)
646 content[str(t)] = {'queue': queue, 'completed': completed }
644 content[str(t)] = {'queue': queue, 'completed': completed }
647 # pending
645 # pending
648 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
646 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
649
647
650 def purge_results(self, client_id, msg):
648 def purge_results(self, client_id, msg):
651 content = msg['content']
649 content = msg['content']
652 msg_ids = content.get('msg_ids', [])
650 msg_ids = content.get('msg_ids', [])
653 reply = dict(status='ok')
651 reply = dict(status='ok')
654 if msg_ids == 'all':
652 if msg_ids == 'all':
655 self.results = {}
653 self.results = {}
656 else:
654 else:
657 for msg_id in msg_ids:
655 for msg_id in msg_ids:
658 if msg_id in self.results:
656 if msg_id in self.results:
659 self.results.pop(msg_id)
657 self.results.pop(msg_id)
660 else:
658 else:
661 if msg_id in self.pending:
659 if msg_id in self.pending:
662 reply = dict(status='error', reason="msg pending: %r"%msg_id)
660 reply = dict(status='error', reason="msg pending: %r"%msg_id)
663 else:
661 else:
664 reply = dict(status='error', reason="No such msg: %r"%msg_id)
662 reply = dict(status='error', reason="No such msg: %r"%msg_id)
665 break
663 break
666 eids = content.get('engine_ids', [])
664 eids = content.get('engine_ids', [])
667 for eid in eids:
665 for eid in eids:
668 if eid not in self.engines:
666 if eid not in self.engines:
669 reply = dict(status='error', reason="No such engine: %i"%eid)
667 reply = dict(status='error', reason="No such engine: %i"%eid)
670 break
668 break
671 msg_ids = self.completed.pop(eid)
669 msg_ids = self.completed.pop(eid)
672 for msg_id in msg_ids:
670 for msg_id in msg_ids:
673 self.results.pop(msg_id)
671 self.results.pop(msg_id)
674
672
675 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
673 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
676
674
677 def resubmit_task(self, client_id, msg, buffers):
675 def resubmit_task(self, client_id, msg, buffers):
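# NOTE (editor): as committed, the body below duplicates purge_results and
# still sends a 'purge_reply'; actual task resubmission is not implemented yet.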
678 content = msg['content']
676 content = msg['content']
679 header = msg['header']
677 header = msg['header']
680
678
681
679
682 msg_ids = content.get('msg_ids', [])
680 msg_ids = content.get('msg_ids', [])
683 reply = dict(status='ok')
681 reply = dict(status='ok')
684 if msg_ids == 'all':
682 if msg_ids == 'all':
685 self.results = {}
683 self.results = {}
686 else:
684 else:
687 for msg_id in msg_ids:
685 for msg_id in msg_ids:
688 if msg_id in self.results:
686 if msg_id in self.results:
689 self.results.pop(msg_id)
687 self.results.pop(msg_id)
690 else:
688 else:
691 if msg_id in self.pending:
689 if msg_id in self.pending:
692 reply = dict(status='error', reason="msg pending: %r"%msg_id)
690 reply = dict(status='error', reason="msg pending: %r"%msg_id)
693 else:
691 else:
694 reply = dict(status='error', reason="No such msg: %r"%msg_id)
692 reply = dict(status='error', reason="No such msg: %r"%msg_id)
695 break
693 break
696 eids = content.get('engine_ids', [])
694 eids = content.get('engine_ids', [])
697 for eid in eids:
695 for eid in eids:
698 if eid not in self.engines:
696 if eid not in self.engines:
699 reply = dict(status='error', reason="No such engine: %i"%eid)
697 reply = dict(status='error', reason="No such engine: %i"%eid)
700 break
698 break
701 msg_ids = self.completed.pop(eid)
699 msg_ids = self.completed.pop(eid)
702 for msg_id in msg_ids:
700 for msg_id in msg_ids:
703 self.results.pop(msg_id)
701 self.results.pop(msg_id)
704
702
705 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
703 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
706
704
707 def get_results(self, client_id, msg):
705 def get_results(self, client_id, msg):
708 """get the result of 1 or more messages"""
706 """get the result of 1 or more messages"""
709 content = msg['content']
707 content = msg['content']
710 msg_ids = set(content['msg_ids'])
708 msg_ids = set(content['msg_ids'])
711 statusonly = content.get('status_only', False)
709 statusonly = content.get('status_only', False)
712 pending = []
710 pending = []
713 completed = []
711 completed = []
714 content = dict(status='ok')
712 content = dict(status='ok')
715 content['pending'] = pending
713 content['pending'] = pending
716 content['completed'] = completed
714 content['completed'] = completed
717 for msg_id in msg_ids:
715 for msg_id in msg_ids:
718 if msg_id in self.pending:
716 if msg_id in self.pending:
719 pending.append(msg_id)
717 pending.append(msg_id)
720 elif msg_id in self.results:
718 elif msg_id in self.results:
721 completed.append(msg_id)
719 completed.append(msg_id)
722 if not statusonly:
720 if not statusonly:
723 content[msg_id] = self.results[msg_id]['content']
721 content[msg_id] = self.results[msg_id]['content']
724 else:
722 else:
725 content = dict(status='error')
723 content = dict(status='error')
726 content['reason'] = 'no such message: '+msg_id
724 content['reason'] = 'no such message: '+msg_id
727 break
725 break
728 self.session.send(self.clientele, "result_reply", content=content,
726 self.session.send(self.clientele, "result_reply", content=content,
729 parent=msg, ident=client_id)
727 parent=msg, ident=client_id)
730
728
731
729
732
730
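# -- Editor's sketch (illustration, not part of the original diff) -----------
# A result request/reply pair, as implied by get_results above (msg ids are
# made up):
#
#   request content:  {'msg_ids': ['msg-1', 'msg-2'], 'status_only': False}
#   reply content:    {'status': 'ok',
#                      'pending':   ['msg-2'],    # still in flight
#                      'completed': ['msg-1'],    # finished
#                      'msg-1': {...}}            # result content,
#                                                 # omitted if status_only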
733 ############ OLD METHODS for Python Relay Controller ###################
731 ############ OLD METHODS for Python Relay Controller ###################
734 def _validate_engine_msg(self, msg):
732 def _validate_engine_msg(self, msg):
735 """validates and unpacks headers of a message. Returns False if invalid,
733 """validates and unpacks headers of a message. Returns False if invalid,
736 (ident, message)"""
734 (ident, message)"""
737 ident = msg[0]
735 ident = msg[0]
738 try:
736 try:
739 msg = self.session.unpack_message(msg[1:], content=False)
737 msg = self.session.unpack_message(msg[1:], content=False)
740 except:
738 except:
741 logger.error("engine.%s::Invalid Message %s"%(ident, msg))
739 logger.error("engine.%s::Invalid Message %s"%(ident, msg))
742 return False
740 return False
743
741
744 try:
742 try:
745 eid = msg.header.username
743 eid = msg.header.username
746 assert self.engines.has_key(eid)
744 assert self.engines.has_key(eid)
747 except:
745 except:
748 logger.error("engine::Invalid Engine ID %s"%(ident))
746 logger.error("engine::Invalid Engine ID %s"%(ident))
749 return False
747 return False
750
748
751 return eid, msg
749 return eid, msg
752
750
753
751
754 #--------------------
752 #--------------------
755 # Entry Point
753 # Entry Point
756 #--------------------
754 #--------------------
757
755
758 def main():
756 def main():
759 import time
757 import time
760 from multiprocessing import Process
758 from multiprocessing import Process
761
759
762 from zmq.eventloop.zmqstream import ZMQStream
760 from zmq.eventloop.zmqstream import ZMQStream
763 from zmq.devices import ProcessMonitoredQueue
761 from zmq.devices import ProcessMonitoredQueue
764 from zmq.log import handlers
762 from zmq.log import handlers
765
763
766 import streamsession as session
764 import streamsession as session
767 import heartmonitor
765 import heartmonitor
768 from scheduler import launch_scheduler
766 from scheduler import launch_scheduler
769
767
770 parser = make_argument_parser()
768 parser = make_argument_parser()
771
769
772 parser.add_argument('--client', type=int, metavar='PORT', default=0,
770 parser.add_argument('--client', type=int, metavar='PORT', default=0,
773 help='set the XREP port for clients [default: random]')
771 help='set the XREP port for clients [default: random]')
774 parser.add_argument('--notice', type=int, metavar='PORT', default=0,
772 parser.add_argument('--notice', type=int, metavar='PORT', default=0,
775 help='set the PUB socket for registration notification [default: random]')
773 help='set the PUB socket for registration notification [default: random]')
776 parser.add_argument('--hb', type=str, metavar='PORTS',
774 parser.add_argument('--hb', type=str, metavar='PORTS',
777 help='set the 2 ports for heartbeats [default: random]')
775 help='set the 2 ports for heartbeats [default: random]')
778 parser.add_argument('--ping', type=int, default=3000,
776 parser.add_argument('--ping', type=int, default=3000,
779 help='set the heartbeat period in ms [default: 3000]')
777 help='set the heartbeat period in ms [default: 3000]')
780 parser.add_argument('--monitor', type=int, metavar='PORT', default=0,
778 parser.add_argument('--monitor', type=int, metavar='PORT', default=0,
781 help='set the SUB port for queue monitoring [default: random]')
779 help='set the SUB port for queue monitoring [default: random]')
782 parser.add_argument('--mux', type=str, metavar='PORTS',
780 parser.add_argument('--mux', type=str, metavar='PORTS',
783 help='set the XREP ports for the MUX queue [default: random]')
781 help='set the XREP ports for the MUX queue [default: random]')
784 parser.add_argument('--task', type=str, metavar='PORTS',
782 parser.add_argument('--task', type=str, metavar='PORTS',
785 help='set the XREP/XREQ ports for the task queue [default: random]')
783 help='set the XREP/XREQ ports for the task queue [default: random]')
786 parser.add_argument('--control', type=str, metavar='PORTS',
784 parser.add_argument('--control', type=str, metavar='PORTS',
787 help='set the XREP ports for the control queue [default: random]')
785 help='set the XREP ports for the control queue [default: random]')
788 parser.add_argument('--scheduler', type=str, default='pure',
786 parser.add_argument('--scheduler', type=str, default='pure',
789 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
787 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
790 help='select the task scheduler [default: pure ZMQ]')
788 help='select the task scheduler [default: pure ZMQ]')
791
789
792 args = parser.parse_args()
790 args = parser.parse_args()
793
791
794 if args.url:
792 if args.url:
795 args.transport,iface = args.url.split('://')
793 args.transport,iface = args.url.split('://')
796 iface = iface.split(':')
794 iface = iface.split(':')
797 args.ip = iface[0]
795 args.ip = iface[0]
798 if iface[1]:
796 if iface[1]:
799 args.regport = iface[1]
797 args.regport = iface[1]
800
798
801 iface="%s://%s"%(args.transport,args.ip)+':%i'
799 iface="%s://%s"%(args.transport,args.ip)+':%i'
802
800
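# -- Editor's note (illustration, not part of the original diff) -------------
# iface is a template with a trailing %i placeholder for the port, e.g. with
# transport='tcp' and ip='127.0.0.1':
#     iface           == 'tcp://127.0.0.1:%i'
#     iface % 12345   == 'tcp://127.0.0.1:12345'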
803 random_ports = 0
801 random_ports = 0
804 if args.hb:
802 if args.hb:
805 hb = split_ports(args.hb, 2)
803 hb = split_ports(args.hb, 2)
806 else:
804 else:
807 hb = select_random_ports(2)
805 hb = select_random_ports(2)
808 if args.mux:
806 if args.mux:
809 mux = split_ports(args.mux, 2)
807 mux = split_ports(args.mux, 2)
810 else:
808 else:
811 mux = None
809 mux = None
812 random_ports += 2
810 random_ports += 2
813 if args.task:
811 if args.task:
814 task = split_ports(args.task, 2)
812 task = split_ports(args.task, 2)
815 else:
813 else:
816 task = None
814 task = None
817 random_ports += 2
815 random_ports += 2
818 if args.control:
816 if args.control:
819 control = split_ports(args.control, 2)
817 control = split_ports(args.control, 2)
820 else:
818 else:
821 control = None
819 control = None
822 random_ports += 2
820 random_ports += 2
823
821
824 ctx = zmq.Context()
822 ctx = zmq.Context()
825 loop = ioloop.IOLoop.instance()
823 loop = ioloop.IOLoop.instance()
826
824
827 # setup logging
825 # setup logging
828 connect_logger(ctx, iface%args.logport, root="controller", loglevel=args.loglevel)
826 connect_logger(ctx, iface%args.logport, root="controller", loglevel=args.loglevel)
829
827
830 # Registrar socket
828 # Registrar socket
831 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
829 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
832 regport = bind_port(reg, args.ip, args.regport)
830 regport = bind_port(reg, args.ip, args.regport)
833
831
834 ### Engine connections ###
832 ### Engine connections ###
835
833
836 # heartbeat
834 # heartbeat
837 hpub = ctx.socket(zmq.PUB)
835 hpub = ctx.socket(zmq.PUB)
838 bind_port(hpub, args.ip, hb[0])
836 bind_port(hpub, args.ip, hb[0])
839 hrep = ctx.socket(zmq.XREP)
837 hrep = ctx.socket(zmq.XREP)
840 bind_port(hrep, args.ip, hb[1])
838 bind_port(hrep, args.ip, hb[1])
841
839
842 hmon = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),args.ping)
840 hmon = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),args.ping)
843 hmon.start()
841 hmon.start()
844
842
845 ### Client connections ###
843 ### Client connections ###
846 # Clientele socket
844 # Clientele socket
847 c = ZMQStream(ctx.socket(zmq.XREP), loop)
845 c = ZMQStream(ctx.socket(zmq.XREP), loop)
848 cport = bind_port(c, args.ip, args.client)
846 cport = bind_port(c, args.ip, args.client)
849 # Notifier socket
847 # Notifier socket
850 n = ZMQStream(ctx.socket(zmq.PUB), loop)
848 n = ZMQStream(ctx.socket(zmq.PUB), loop)
851 nport = bind_port(n, args.ip, args.notice)
849 nport = bind_port(n, args.ip, args.notice)
852
850
853 thesession = session.StreamSession(username=args.ident or "controller")
851 thesession = session.StreamSession(username=args.ident or "controller")
854
852
855 ### build and launch the queues ###
853 ### build and launch the queues ###
856
854
857 # monitor socket
855 # monitor socket
858 sub = ctx.socket(zmq.SUB)
856 sub = ctx.socket(zmq.SUB)
859 sub.setsockopt(zmq.SUBSCRIBE, "")
857 sub.setsockopt(zmq.SUBSCRIBE, "")
860 monport = bind_port(sub, args.ip, args.monitor)
858 monport = bind_port(sub, args.ip, args.monitor)
861 sub = ZMQStream(sub, loop)
859 sub = ZMQStream(sub, loop)
862
860
863 ports = select_random_ports(random_ports)
861 ports = select_random_ports(random_ports)
864 # Multiplexer Queue (in a Process)
862 # Multiplexer Queue (in a Process)
865 if not mux:
863 if not mux:
866 mux = (ports.pop(),ports.pop())
864 mux = (ports.pop(),ports.pop())
867 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
865 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
868 q.bind_in(iface%mux[0])
866 q.bind_in(iface%mux[0])
869 q.bind_out(iface%mux[1])
867 q.bind_out(iface%mux[1])
870 q.connect_mon(iface%monport)
868 q.connect_mon(iface%monport)
871 q.daemon=True
869 q.daemon=True
872 q.start()
870 q.start()
873
871
874 # Control Queue (in a Process)
872 # Control Queue (in a Process)
875 if not control:
873 if not control:
876 control = (ports.pop(),ports.pop())
874 control = (ports.pop(),ports.pop())
877 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
875 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
878 q.bind_in(iface%control[0])
876 q.bind_in(iface%control[0])
879 q.bind_out(iface%control[1])
877 q.bind_out(iface%control[1])
880 q.connect_mon(iface%monport)
878 q.connect_mon(iface%monport)
881 q.daemon=True
879 q.daemon=True
882 q.start()
880 q.start()
883
881
884 # Task Queue (in a Process)
882 # Task Queue (in a Process)
885 if not task:
883 if not task:
886 task = (ports.pop(),ports.pop())
884 task = (ports.pop(),ports.pop())
887 if args.scheduler == 'pure':
885 if args.scheduler == 'pure':
888 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
886 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
889 q.bind_in(iface%task[0])
887 q.bind_in(iface%task[0])
890 q.bind_out(iface%task[1])
888 q.bind_out(iface%task[1])
891 q.connect_mon(iface%monport)
889 q.connect_mon(iface%monport)
892 q.daemon=True
890 q.daemon=True
893 q.start()
891 q.start()
894 else:
892 else:
895 sargs = (iface%task[0],iface%task[1],iface%monport,iface%nport,args.scheduler)
893 sargs = (iface%task[0],iface%task[1],iface%monport,iface%nport,args.scheduler)
896 print sargs
894 print sargs
897 p = Process(target=launch_scheduler, args=sargs)
895 p = Process(target=launch_scheduler, args=sargs)
898 p.daemon=True
896 p.daemon=True
899 p.start()
897 p.start()
900
898
901 time.sleep(.25)
899 time.sleep(.25)
902
900
903 # build connection dicts
901 # build connection dicts
904 engine_addrs = {
902 engine_addrs = {
905 'control' : iface%control[1],
903 'control' : iface%control[1],
906 'queue': iface%mux[1],
904 'queue': iface%mux[1],
907 'heartbeat': (iface%hb[0], iface%hb[1]),
905 'heartbeat': (iface%hb[0], iface%hb[1]),
908 'task' : iface%task[1],
906 'task' : iface%task[1],
909 'monitor' : iface%monport,
907 'monitor' : iface%monport,
910 }
908 }
911
909
912 client_addrs = {
910 client_addrs = {
913 'control' : iface%control[0],
911 'control' : iface%control[0],
914 'query': iface%cport,
912 'query': iface%cport,
915 'queue': iface%mux[0],
913 'queue': iface%mux[0],
916 'task' : iface%task[0],
914 'task' : iface%task[0],
917 'notification': iface%nport
915 'notification': iface%nport
918 }
916 }
919 con = Controller(loop, thesession, sub, reg, hmon, c, n, None, engine_addrs, client_addrs)
917 con = Controller(loop, thesession, sub, reg, hmon, c, n, None, engine_addrs, client_addrs)
920 loop.start()
918 loop.start()
921
919
@@ -1,401 +1,404 b''
1 #----------------------------------------------------------------------
1 #----------------------------------------------------------------------
2 # Imports
2 # Imports
3 #----------------------------------------------------------------------
3 #----------------------------------------------------------------------
4
4
5 from random import randint,random
5 from random import randint,random
6
6
7 try:
7 try:
8 import numpy
8 import numpy
9 except ImportError:
9 except ImportError:
10 numpy = None
10 numpy = None
11
11
12 import zmq
12 import zmq
13 from zmq.eventloop import ioloop, zmqstream
13 from zmq.eventloop import ioloop, zmqstream
14
14
15 # local imports
15 # local imports
16 from IPython.zmq.log import logger # a Logger object
16 from IPython.zmq.log import logger # a Logger object
17 from client import Client
17 from client import Client
18 from dependency import Dependency
18 from dependency import Dependency
19 import streamsession as ss
19 import streamsession as ss
20
20
21 from IPython.external.decorator import decorator
21 from IPython.external.decorator import decorator
22
22
23 @decorator
23 @decorator
24 def logged(f,self,*args,**kwargs):
24 def logged(f,self,*args,**kwargs):
25 print ("#--------------------")
25 print ("#--------------------")
26 print ("%s(*%s,**%s)"%(f.func_name, args, kwargs))
26 print ("%s(*%s,**%s)"%(f.func_name, args, kwargs))
27 print ("#--")
27 return f(self,*args, **kwargs)
28 return f(self,*args, **kwargs)
28
29
29 #----------------------------------------------------------------------
30 #----------------------------------------------------------------------
30 # Chooser functions
31 # Chooser functions
31 #----------------------------------------------------------------------
32 #----------------------------------------------------------------------
32
33
33 def plainrandom(loads):
34 def plainrandom(loads):
34 """Plain random pick."""
35 """Plain random pick."""
35 n = len(loads)
36 n = len(loads)
36 return randint(0,n-1)
37 return randint(0,n-1)
37
38
38 def lru(loads):
39 def lru(loads):
39 """Always pick the front of the line.
40 """Always pick the front of the line.
40
41
41 The content of loads is ignored.
42 The content of loads is ignored.
42
43
43 Assumes LRU ordering of loads, with oldest first.
44 Assumes LRU ordering of loads, with oldest first.
44 """
45 """
45 return 0
46 return 0
46
47
47 def twobin(loads):
48 def twobin(loads):
48 """Pick two at random, use the LRU of the two.
49 """Pick two at random, use the LRU of the two.
49
50
50 The content of loads is ignored.
51 The content of loads is ignored.
51
52
52 Assumes LRU ordering of loads, with oldest first.
53 Assumes LRU ordering of loads, with oldest first.
53 """
54 """
54 n = len(loads)
55 n = len(loads)
55 a = randint(0,n-1)
56 a = randint(0,n-1)
56 b = randint(0,n-1)
57 b = randint(0,n-1)
57 return min(a,b)
58 return min(a,b)
58
59
59 def weighted(loads):
60 def weighted(loads):
60 """Pick two at random using inverse load as weight.
61 """Pick two at random using inverse load as weight.
61
62
62 Return the less loaded of the two.
63 Return the less loaded of the two.
63 """
64 """
64 # weight 0 a million times more than 1:
65 # weight 0 a million times more than 1:
65 weights = 1./(1e-6+numpy.array(loads))
66 weights = 1./(1e-6+numpy.array(loads))
66 sums = weights.cumsum()
67 sums = weights.cumsum()
67 t = sums[-1]
68 t = sums[-1]
68 x = random()*t
69 x = random()*t
69 y = random()*t
70 y = random()*t
70 idx = 0
71 idx = 0
71 idy = 0
72 idy = 0
72 while sums[idx] < x:
73 while sums[idx] < x:
73 idx += 1
74 idx += 1
74 while sums[idy] < y:
75 while sums[idy] < y:
75 idy += 1
76 idy += 1
76 if weights[idy] > weights[idx]:
77 if weights[idy] > weights[idx]:
77 return idy
78 return idy
78 else:
79 else:
79 return idx
80 return idx
80
81
81 def leastload(loads):
82 def leastload(loads):
82 """Always choose the lowest load.
83 """Always choose the lowest load.
83
84
84 If the lowest load occurs more than once, the first
85 If the lowest load occurs more than once, the first
85 occurrence will be used. If loads has LRU ordering, this means
86 occurrence will be used. If loads has LRU ordering, this means
86 the LRU of those with the lowest load is chosen.
87 the LRU of those with the lowest load is chosen.
87 """
88 """
88 return loads.index(min(loads))
89 return loads.index(min(loads))
89
90
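# -- Editor's demo (illustration, not part of the original diff) -------------
# Each chooser above maps a list of per-engine loads to an index into that
# list. The loads below are made up; plainrandom/twobin/weighted are
# randomized, and weighted requires numpy.
def _demo_choosers():
    loads = [2, 0, 5, 1]
    print 'lru:      ', lru(loads)          # always 0 (front of the line)
    print 'leastload:', leastload(loads)    # 1, the index of the minimum load
    print 'twobin:   ', twobin(loads)       # random pair, lower index wins
    if numpy is not None:
        print 'weighted: ', weighted(loads) # inverse-load weighted pick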
90 #---------------------------------------------------------------------
91 #---------------------------------------------------------------------
91 # Classes
92 # Classes
92 #---------------------------------------------------------------------
93 #---------------------------------------------------------------------
93 class TaskScheduler(object):
94 class TaskScheduler(object):
94 """Simple Python TaskScheduler object.
95 """Python TaskScheduler object.
95
96
96 This is the simplest object that supports msg_id based
97 This is the simplest object that supports msg_id based
97 DAG dependencies. *Only* task msg_ids are checked, not
98 DAG dependencies. *Only* task msg_ids are checked, not
98 msg_ids of jobs submitted via the MUX queue.
99 msg_ids of jobs submitted via the MUX queue.
99
100
100 """
101 """
101
102
102 scheme = leastload # function for determining the destination
103 scheme = leastload # function for determining the destination
103 client_stream = None # client-facing stream
104 client_stream = None # client-facing stream
104 engine_stream = None # engine-facing stream
105 engine_stream = None # engine-facing stream
105 mon_stream = None # controller-facing stream
106 mon_stream = None # controller-facing stream
106 dependencies = None # dict by msg_id of [ msg_ids that depend on key ]
107 dependencies = None # dict by msg_id of [ msg_ids that depend on key ]
107 depending = None # dict by msg_id of (msg_id, raw_msg, after, follow)
108 depending = None # dict by msg_id of (msg_id, raw_msg, after, follow)
108 pending = None # dict by engine_uuid of submitted tasks
109 pending = None # dict by engine_uuid of submitted tasks
109 completed = None # dict by engine_uuid of completed tasks
110 completed = None # dict by engine_uuid of completed tasks
110 clients = None # dict by msg_id for who submitted the task
111 clients = None # dict by msg_id for who submitted the task
111 targets = None # list of target IDENTs
112 targets = None # list of target IDENTs
112 loads = None # list of engine loads
113 loads = None # list of engine loads
113 all_done = None # set of all completed tasks
114 all_done = None # set of all completed tasks
114 blacklist = None # dict by msg_id of locations where a job has encountered UnmetDependency
115 blacklist = None # dict by msg_id of locations where a job has encountered UnmetDependency
115
116
116
117
117 def __init__(self, client_stream, engine_stream, mon_stream,
118 def __init__(self, client_stream, engine_stream, mon_stream,
118 notifier_stream, scheme=None, io_loop=None):
119 notifier_stream, scheme=None, io_loop=None):
119 if io_loop is None:
120 if io_loop is None:
120 io_loop = ioloop.IOLoop.instance()
121 io_loop = ioloop.IOLoop.instance()
121 self.io_loop = io_loop
122 self.io_loop = io_loop
122 self.client_stream = client_stream
123 self.client_stream = client_stream
123 self.engine_stream = engine_stream
124 self.engine_stream = engine_stream
124 self.mon_stream = mon_stream
125 self.mon_stream = mon_stream
125 self.notifier_stream = notifier_stream
126 self.notifier_stream = notifier_stream
126
127
127 if scheme is not None:
128 if scheme is not None:
128 self.scheme = scheme
129 self.scheme = scheme
129 else:
130 else:
130 self.scheme = TaskScheduler.scheme
131 self.scheme = TaskScheduler.scheme
131
132
132 self.session = ss.StreamSession(username="TaskScheduler")
133 self.session = ss.StreamSession(username="TaskScheduler")
133
134
134 self.dependencies = {}
135 self.dependencies = {}
135 self.depending = {}
136 self.depending = {}
136 self.completed = {}
137 self.completed = {}
137 self.pending = {}
138 self.pending = {}
138 self.all_done = set()
139 self.all_done = set()
140 self.blacklist = {}
139
141
140 self.targets = []
142 self.targets = []
141 self.loads = []
143 self.loads = []
142
144
143 engine_stream.on_recv(self.dispatch_result, copy=False)
145 engine_stream.on_recv(self.dispatch_result, copy=False)
144 self._notification_handlers = dict(
146 self._notification_handlers = dict(
145 registration_notification = self._register_engine,
147 registration_notification = self._register_engine,
146 unregistration_notification = self._unregister_engine
148 unregistration_notification = self._unregister_engine
147 )
149 )
148 self.notifier_stream.on_recv(self.dispatch_notification)
150 self.notifier_stream.on_recv(self.dispatch_notification)
149
151
150 def resume_receiving(self):
152 def resume_receiving(self):
151 """resume accepting jobs"""
153 """resume accepting jobs"""
152 self.client_stream.on_recv(self.dispatch_submission, copy=False)
154 self.client_stream.on_recv(self.dispatch_submission, copy=False)
153
155
154 def stop_receiving(self):
156 def stop_receiving(self):
155 self.client_stream.on_recv(None)
157 self.client_stream.on_recv(None)
156
158
157 #-----------------------------------------------------------------------
159 #-----------------------------------------------------------------------
158 # [Un]Registration Handling
160 # [Un]Registration Handling
159 #-----------------------------------------------------------------------
161 #-----------------------------------------------------------------------
160
162
161 def dispatch_notification(self, msg):
163 def dispatch_notification(self, msg):
162 """dispatch register/unregister events."""
164 """dispatch register/unregister events."""
163 idents,msg = self.session.feed_identities(msg)
165 idents,msg = self.session.feed_identities(msg)
164 msg = self.session.unpack_message(msg)
166 msg = self.session.unpack_message(msg)
165 msg_type = msg['msg_type']
167 msg_type = msg['msg_type']
166 handler = self._notification_handlers.get(msg_type, None)
168 handler = self._notification_handlers.get(msg_type, None)
167 if handler is None:
169 if handler is None:
168 raise Exception("Unhandled message type: %s"%msg_type)
170 raise Exception("Unhandled message type: %s"%msg_type)
169 else:
171 else:
170 try:
172 try:
171 handler(str(msg['content']['queue']))
173 handler(str(msg['content']['queue']))
172 except KeyError:
174 except KeyError:
173 logger.error("task::Invalid notification msg: %s"%msg)
175 logger.error("task::Invalid notification msg: %s"%msg)
174 @logged
176 @logged
175 def _register_engine(self, uid):
177 def _register_engine(self, uid):
176 """new engine became available"""
178 """new engine became available"""
177 # head of the line:
179 # head of the line:
178 self.targets.insert(0,uid)
180 self.targets.insert(0,uid)
179 self.loads.insert(0,0)
181 self.loads.insert(0,0)
180 # initialize sets
182 # initialize sets
181 self.completed[uid] = set()
183 self.completed[uid] = set()
182 self.pending[uid] = {}
184 self.pending[uid] = {}
183 if len(self.targets) == 1:
185 if len(self.targets) == 1:
184 self.resume_receiving()
186 self.resume_receiving()
185
187
186 def _unregister_engine(self, uid):
188 def _unregister_engine(self, uid):
187 """existing engine became unavailable"""
189 """existing engine became unavailable"""
188 # handle any potentially finished tasks:
190 # handle any potentially finished tasks:
189 if len(self.targets) == 1:
191 if len(self.targets) == 1:
190 self.stop_receiving()
192 self.stop_receiving()
191 self.engine_stream.flush()
193 self.engine_stream.flush()
192
194
193 self.completed.pop(uid)
195 self.completed.pop(uid)
194 lost = self.pending.pop(uid)
196 lost = self.pending.pop(uid)
195
197
196 idx = self.targets.index(uid)
198 idx = self.targets.index(uid)
197 self.targets.pop(idx)
199 self.targets.pop(idx)
198 self.loads.pop(idx)
200 self.loads.pop(idx)
199
201
200 self.handle_stranded_tasks(lost)
202 self.handle_stranded_tasks(lost)
201
203
202 def handle_stranded_tasks(self, lost):
204 def handle_stranded_tasks(self, lost):
203 """deal with jobs resident in an engine that died."""
205 """deal with jobs resident in an engine that died."""
204 # TODO: resubmit the tasks?
206 # TODO: resubmit the tasks?
205 for msg_id in lost:
207 for msg_id in lost:
206 pass
208 pass
207
209
208
210
209 #-----------------------------------------------------------------------
211 #-----------------------------------------------------------------------
210 # Job Submission
212 # Job Submission
211 #-----------------------------------------------------------------------
213 #-----------------------------------------------------------------------
212 @logged
214 @logged
213 def dispatch_submission(self, raw_msg):
215 def dispatch_submission(self, raw_msg):
214 """dispatch job submission"""
216 """dispatch job submission"""
215 # ensure targets up to date:
217 # ensure targets up to date:
216 self.notifier_stream.flush()
218 self.notifier_stream.flush()
217 try:
219 try:
218 idents, msg = self.session.feed_identities(raw_msg, copy=False)
220 idents, msg = self.session.feed_identities(raw_msg, copy=False)
219 except Exception, e:
221 except Exception, e:
220 logger.error("task::Invalid msg: %s"%raw_msg)
222 logger.error("task::Invalid msg: %s"%raw_msg)
221 return
223 return
222
224
223 msg = self.session.unpack_message(msg, content=False, copy=False)
225 msg = self.session.unpack_message(msg, content=False, copy=False)
224 print idents,msg
225 header = msg['header']
226 header = msg['header']
226 msg_id = header['msg_id']
227 msg_id = header['msg_id']
227 after = Dependency(header.get('after', []))
228 after = Dependency(header.get('after', []))
228 if after.mode == 'all':
229 if after.mode == 'all':
229 after.difference_update(self.all_done)
230 after.difference_update(self.all_done)
230 if after.check(self.all_done):
231 if after.check(self.all_done):
231 # recast as empty set, if we are already met,
232 # recast as empty set, if we are already met,
232 # to prevent it from being tracked as an unmet dependency
233 # to prevent it from being tracked as an unmet dependency
233 after = Dependency([])
234 after = Dependency([])
234
235
235 follow = Dependency(header.get('follow', []))
236 follow = Dependency(header.get('follow', []))
236 print raw_msg
237 if len(after) == 0:
237 if len(after) == 0:
238 # time deps already met, try to run
238 # time deps already met, try to run
239 if not self.maybe_run(msg_id, raw_msg, follow):
239 if not self.maybe_run(msg_id, raw_msg, follow):
240 # can't run yet
240 # can't run yet
241 self.save_unmet(msg_id, raw_msg, after, follow)
241 self.save_unmet(msg_id, raw_msg, after, follow)
242 else:
242 else:
243 self.save_unmet(msg_id, raw_msg, after, follow)
243 self.save_unmet(msg_id, raw_msg, after, follow)
244 # send to monitor
244 # send to monitor
245 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
245 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
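# -- Editor's sketch (illustration, not part of the original diff) -----------
# dispatch_submission reads two optional header fields set by the submitting
# client (msg ids below are made up):
#     header = {'msg_id': 'task-3',
#               'after':  ['task-1'],   # time dependency: run once these are done
#               'follow': ['task-2']}   # location dependency: run where these ran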
246 @logged
246 @logged
247 def maybe_run(self, msg_id, raw_msg, follow=None):
247 def maybe_run(self, msg_id, raw_msg, follow=None):
248 """check location dependencies, and run if they are met."""
248 """check location dependencies, and run if they are met."""
249
249
250 if follow:
250 if follow:
251 def can_run(idx):
251 def can_run(idx):
252 target = self.targets[idx]
252 target = self.targets[idx]
253 return target not in self.blacklist.get(msg_id, []) and\
253 return target not in self.blacklist.get(msg_id, []) and\
254 follow.check(self.completed[target])
254 follow.check(self.completed[target])
255
255
256 indices = filter(can_run, range(len(self.targets)))
256 indices = filter(can_run, range(len(self.targets)))
257 if not indices:
257 if not indices:
258 return False
258 return False
259 else:
259 else:
260 indices = None
260 indices = None
261
261
262 self.submit_task(msg_id, raw_msg, indices)
262 self.submit_task(msg_id, raw_msg, indices)
263 return True
263 return True
264
264
265 @logged
265 @logged
266 def save_unmet(self, msg_id, msg, after, follow):
266 def save_unmet(self, msg_id, msg, after, follow):
267 """Save a message for later submission when its dependencies are met."""
267 """Save a message for later submission when its dependencies are met."""
268 self.depending[msg_id] = (msg_id,msg,after,follow)
268 self.depending[msg_id] = (msg_id,msg,after,follow)
269 # track the ids in both follow/after, but not those already completed
269 # track the ids in both follow/after, but not those already completed
270 for dep_id in after.union(follow).difference(self.all_done):
270 for dep_id in after.union(follow).difference(self.all_done):
271 print dep_id
271 if dep_id not in self.dependencies:
272 if dep_id not in self.dependencies:
272 self.dependencies[dep_id] = set()
273 self.dependencies[dep_id] = set()
273 self.dependencies[dep_id].add(msg_id)
274 self.dependencies[dep_id].add(msg_id)
275
274 @logged
276 @logged
275 def submit_task(self, msg_id, msg, follow=None, indices=None):
277 def submit_task(self, msg_id, msg, follow=None, indices=None):
276 """submit a task to any of a subset of our targets"""
278 """submit a task to any of a subset of our targets"""
277 if indices:
279 if indices:
278 loads = [self.loads[i] for i in indices]
280 loads = [self.loads[i] for i in indices]
279 else:
281 else:
280 loads = self.loads
282 loads = self.loads
281 idx = self.scheme(loads)
283 idx = self.scheme(loads)
282 if indices:
284 if indices:
283 idx = indices[idx]
285 idx = indices[idx]
284 target = self.targets[idx]
286 target = self.targets[idx]
285 print target, map(str, msg[:3])
287 print target, map(str, msg[:3])
286 self.engine_stream.socket.send(target, flags=zmq.SNDMORE, copy=False)
288 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
287 self.engine_stream.socket.send_multipart(msg, copy=False)
289 self.engine_stream.send_multipart(msg, copy=False)
288 self.add_job(idx)
290 self.add_job(idx)
289 self.pending[target][msg_id] = (msg, follow)
291 self.pending[target][msg_id] = (msg, follow)
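For reference, the dependency information this scheduler acts on lives in the task header; a minimal illustration (the msg_id values are made up) of a task gated on both kinds of dependency:

    header = {
        'msg_id' : 'task-3',     # this task's id
        'after'  : ['task-1'],   # time deps: run only after these msg_ids complete anywhere
        'follow' : ['task-2'],   # location deps: run only on an engine that completed these msg_ids
    }

dispatch_submission wraps both lists in Dependency objects; if `after` is already satisfied the task goes straight to maybe_run, otherwise save_unmet parks it until update_dependencies releases it.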
290
292
291 #-----------------------------------------------------------------------
293 #-----------------------------------------------------------------------
292 # Result Handling
294 # Result Handling
293 #-----------------------------------------------------------------------
295 #-----------------------------------------------------------------------
294 @logged
296 @logged
295 def dispatch_result(self, raw_msg):
297 def dispatch_result(self, raw_msg):
296 try:
298 try:
297 idents,msg = self.session.feed_identities(raw_msg, copy=False)
299 idents,msg = self.session.feed_identities(raw_msg, copy=False)
298 except Exception, e:
300 except Exception, e:
299 logger.error("task::Invalid result: %s"%raw_msg)
301 logger.error("task::Invalid result: %s"%raw_msg)
300 return
302 return
301 msg = self.session.unpack_message(msg, content=False, copy=False)
303 msg = self.session.unpack_message(msg, content=False, copy=False)
302 header = msg['header']
304 header = msg['header']
303 if header.get('dependencies_met', True):
305 if header.get('dependencies_met', True):
304 self.handle_result_success(idents, msg['parent_header'], raw_msg)
306 self.handle_result_success(idents, msg['parent_header'], raw_msg)
305 # send to monitor
307 # send to monitor
306 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
308 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
307 else:
309 else:
308 self.handle_unmet_dependency(self, idents, msg['parent_header'])
310 self.handle_unmet_dependency(idents, msg['parent_header'])
309
311
310 @logged
312 @logged
311 def handle_result_success(self, idents, parent, raw_msg):
313 def handle_result_success(self, idents, parent, raw_msg):
312 # first, relay result to client
314 # first, relay result to client
313 engine = idents[0]
315 engine = idents[0]
314 client = idents[1]
316 client = idents[1]
315 # swap_ids for XREP-XREP mirror
317 # swap_ids for XREP-XREP mirror
316 raw_msg[:2] = [client,engine]
318 raw_msg[:2] = [client,engine]
317 print map(str, raw_msg[:4])
319 print map(str, raw_msg[:4])
318 self.client_stream.send_multipart(raw_msg, copy=False)
320 self.client_stream.send_multipart(raw_msg, copy=False)
319 # now, update our data structures
321 # now, update our data structures
320 msg_id = parent['msg_id']
322 msg_id = parent['msg_id']
321 self.pending[engine].pop(msg_id)
323 self.pending[engine].pop(msg_id)
322 self.completed[engine].add(msg_id)
324 self.completed[engine].add(msg_id)
323
325
324 self.update_dependencies(msg_id)
326 self.update_dependencies(msg_id)
325
327
326 @logged
328 @logged
327 def handle_unmet_dependency(self, idents, parent):
329 def handle_unmet_dependency(self, idents, parent):
328 engine = idents[0]
330 engine = idents[0]
329 msg_id = parent['msg_id']
331 msg_id = parent['msg_id']
330 if msg_id not in self.blacklist:
332 if msg_id not in self.blacklist:
331 self.blacklist[msg_id] = set()
333 self.blacklist[msg_id] = set()
332 self.blacklist[msg_id].add(engine)
334 self.blacklist[msg_id].add(engine)
333 raw_msg,follow = self.pending[engine].pop(msg_id)
335 raw_msg,follow = self.pending[engine].pop(msg_id)
334 if not self.maybe_run(raw_msg, follow):
336 if not self.maybe_run(msg_id, raw_msg, follow):
335 # resubmit failed, put it back in our dependency tree
337 # resubmit failed, put it back in our dependency tree
336 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
338 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
337 pass
339 pass
338 @logged
340 @logged
339 def update_dependencies(self, dep_id):
341 def update_dependencies(self, dep_id):
340 """dep_id just finished. Update our dependency
342 """dep_id just finished. Update our dependency
341 table and submit any jobs that just became runnable."""
343 table and submit any jobs that just became runnable."""
342 if dep_id not in self.dependencies:
344 if dep_id not in self.dependencies:
343 return
345 return
344 jobs = self.dependencies.pop(dep_id)
346 jobs = self.dependencies.pop(dep_id)
345 for job in jobs:
347 for job in jobs:
346 msg_id, raw_msg, after, follow = self.depending[job]
348 msg_id, raw_msg, after, follow = self.depending[job]
347 if msg_id in after:
349 if msg_id in after:
348 after.remove(msg_id)
350 after.remove(msg_id)
349 if not after: # time deps met
351 if not after: # time deps met
350 if self.maybe_run(msg_id, raw_msg, follow):
352 if self.maybe_run(msg_id, raw_msg, follow):
351 self.depending.pop(job)
353 self.depending.pop(job)
352 for mid in follow:
354 for mid in follow:
353 self.dependencies[mid].remove(msg_id)
355 if mid in self.dependencies:
356 self.dependencies[mid].remove(msg_id)
354
357
355 #----------------------------------------------------------------------
358 #----------------------------------------------------------------------
356 # methods to be overridden by subclasses
359 # methods to be overridden by subclasses
357 #----------------------------------------------------------------------
360 #----------------------------------------------------------------------
358
361
359 def add_job(self, idx):
362 def add_job(self, idx):
360 """Called after self.targets[idx] just got the job with header.
363 """Called after self.targets[idx] just got the job with header.
361 Override in subclasses. The default ordering is simple LRU.
364 Override in subclasses. The default ordering is simple LRU.
362 The default loads are the number of outstanding jobs."""
365 The default loads are the number of outstanding jobs."""
363 self.loads[idx] += 1
366 self.loads[idx] += 1
364 for lis in (self.targets, self.loads):
367 for lis in (self.targets, self.loads):
365 lis.append(lis.pop(idx))
368 lis.append(lis.pop(idx))
366
369
367
370
368 def finish_job(self, idx):
371 def finish_job(self, idx):
369 """Called after self.targets[idx] just finished a job.
372 """Called after self.targets[idx] just finished a job.
370 Override in subclasses."""
373 Override in subclasses."""
371 self.loads[idx] -= 1
374 self.loads[idx] -= 1
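The `scheme` callable that submit_task applies to the load list (and that launch_scheduler below resolves by name, e.g. 'weighted') is defined elsewhere in this changeset. As a sketch of the contract behind `idx = self.scheme(loads)`, a least-loaded picker could be as simple as this (the name `leastload` is illustrative, not the actual implementation):

    def leastload(loads):
        """Return the index of the engine with the fewest outstanding jobs."""
        return loads.index(min(loads))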
372
375
373
376
374
377
375 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, scheme='weighted'):
378 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, scheme='weighted'):
376 from zmq.eventloop import ioloop
379 from zmq.eventloop import ioloop
377 from zmq.eventloop.zmqstream import ZMQStream
380 from zmq.eventloop.zmqstream import ZMQStream
378
381
379 ctx = zmq.Context()
382 ctx = zmq.Context()
380 loop = ioloop.IOLoop()
383 loop = ioloop.IOLoop()
381
384
382 scheme = globals().get(scheme)
385 scheme = globals().get(scheme)
383
386
384 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
387 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
385 ins.bind(in_addr)
388 ins.bind(in_addr)
386 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
389 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
387 outs.bind(out_addr)
390 outs.bind(out_addr)
388 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
391 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
389 mons.connect(mon_addr)
392 mons.connect(mon_addr)
390 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
393 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
391 nots.setsockopt(zmq.SUBSCRIBE, '')
394 nots.setsockopt(zmq.SUBSCRIBE, '')
392 nots.connect(not_addr)
395 nots.connect(not_addr)
393
396
394 scheduler = TaskScheduler(ins,outs,mons,nots,scheme,loop)
397 scheduler = TaskScheduler(ins,outs,mons,nots,scheme,loop)
395
398
396 loop.start()
399 loop.start()
397
400
398
401
399 if __name__ == '__main__':
402 if __name__ == '__main__':
400 iface = 'tcp://127.0.0.1:%i'
403 iface = 'tcp://127.0.0.1:%i'
401 launch_scheduler(iface%12345,iface%1236,iface%12347,iface%12348)
404 launch_scheduler(iface%12345,iface%1236,iface%12347,iface%12348)
@@ -1,498 +1,499 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4
4
5
5
6 import os
6 import os
7 import sys
7 import sys
8 import traceback
8 import traceback
9 import pprint
9 import pprint
10 import uuid
10 import uuid
11
11
12 import zmq
12 import zmq
13 from zmq.utils import jsonapi
13 from zmq.utils import jsonapi
14 from zmq.eventloop.zmqstream import ZMQStream
14 from zmq.eventloop.zmqstream import ZMQStream
15
15
16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
16 from IPython.zmq.pickleutil import can, uncan, canSequence, uncanSequence
17 from IPython.zmq.newserialized import serialize, unserialize
17 from IPython.zmq.newserialized import serialize, unserialize
18
18
19 try:
19 try:
20 import cPickle
20 import cPickle
21 pickle = cPickle
21 pickle = cPickle
22 except:
22 except:
23 cPickle = None
23 cPickle = None
24 import pickle
24 import pickle
25
25
26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
26 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
27 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
28 if json_name in ('jsonlib', 'jsonlib2'):
28 if json_name in ('jsonlib', 'jsonlib2'):
29 use_json = True
29 use_json = True
30 elif json_name:
30 elif json_name:
31 if cPickle is None:
31 if cPickle is None:
32 use_json = True
32 use_json = True
33 else:
33 else:
34 use_json = False
34 use_json = False
35 else:
35 else:
36 use_json = False
36 use_json = False
37
37
38 def squash_unicode(obj):
38 def squash_unicode(obj):
39 if isinstance(obj,dict):
39 if isinstance(obj,dict):
40 for key in obj.keys():
40 for key in obj.keys():
41 obj[key] = squash_unicode(obj[key])
41 obj[key] = squash_unicode(obj[key])
42 if isinstance(key, unicode):
42 if isinstance(key, unicode):
43 obj[squash_unicode(key)] = obj.pop(key)
43 obj[squash_unicode(key)] = obj.pop(key)
44 elif isinstance(obj, list):
44 elif isinstance(obj, list):
45 for i,v in enumerate(obj):
45 for i,v in enumerate(obj):
46 obj[i] = squash_unicode(v)
46 obj[i] = squash_unicode(v)
47 elif isinstance(obj, unicode):
47 elif isinstance(obj, unicode):
48 obj = obj.encode('utf8')
48 obj = obj.encode('utf8')
49 return obj
49 return obj
50
50
51 if use_json:
51 if use_json:
52 default_packer = jsonapi.dumps
52 default_packer = jsonapi.dumps
53 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
53 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
54 else:
54 else:
55 default_packer = lambda o: pickle.dumps(o,-1)
55 default_packer = lambda o: pickle.dumps(o,-1)
56 default_unpacker = pickle.loads
56 default_unpacker = pickle.loads
57
57
58
58
59 DELIM="<IDS|MSG>"
59 DELIM="<IDS|MSG>"
60
60
61 def wrap_exception():
61 def wrap_exception():
62 etype, evalue, tb = sys.exc_info()
62 etype, evalue, tb = sys.exc_info()
63 tb = traceback.format_exception(etype, evalue, tb)
63 tb = traceback.format_exception(etype, evalue, tb)
64 exc_content = {
64 exc_content = {
65 u'status' : u'error',
65 u'status' : u'error',
66 u'traceback' : tb,
66 u'traceback' : tb,
67 u'etype' : unicode(etype),
67 u'etype' : unicode(etype),
68 u'evalue' : unicode(evalue)
68 u'evalue' : unicode(evalue)
69 }
69 }
70 return exc_content
70 return exc_content
71
71
72 class KernelError(Exception):
72 class KernelError(Exception):
73 pass
73 pass
74
74
75 def unwrap_exception(content):
75 def unwrap_exception(content):
76 err = KernelError(content['etype'], content['evalue'])
76 err = KernelError(content['etype'], content['evalue'])
77 err.evalue = content['evalue']
77 err.evalue = content['evalue']
78 err.etype = content['etype']
78 err.etype = content['etype']
79 err.traceback = ''.join(content['traceback'])
79 err.traceback = ''.join(content['traceback'])
80 return err
80 return err
81
81
82
82
83 class Message(object):
83 class Message(object):
84 """A simple message object that maps dict keys to attributes.
84 """A simple message object that maps dict keys to attributes.
85
85
86 A Message can be created from a dict and a dict from a Message instance
86 A Message can be created from a dict and a dict from a Message instance
87 simply by calling dict(msg_obj)."""
87 simply by calling dict(msg_obj)."""
88
88
89 def __init__(self, msg_dict):
89 def __init__(self, msg_dict):
90 dct = self.__dict__
90 dct = self.__dict__
91 for k, v in dict(msg_dict).iteritems():
91 for k, v in dict(msg_dict).iteritems():
92 if isinstance(v, dict):
92 if isinstance(v, dict):
93 v = Message(v)
93 v = Message(v)
94 dct[k] = v
94 dct[k] = v
95
95
96 # Having this iterator lets dict(msg_obj) work out of the box.
96 # Having this iterator lets dict(msg_obj) work out of the box.
97 def __iter__(self):
97 def __iter__(self):
98 return iter(self.__dict__.iteritems())
98 return iter(self.__dict__.iteritems())
99
99
100 def __repr__(self):
100 def __repr__(self):
101 return repr(self.__dict__)
101 return repr(self.__dict__)
102
102
103 def __str__(self):
103 def __str__(self):
104 return pprint.pformat(self.__dict__)
104 return pprint.pformat(self.__dict__)
105
105
106 def __contains__(self, k):
106 def __contains__(self, k):
107 return k in self.__dict__
107 return k in self.__dict__
108
108
109 def __getitem__(self, k):
109 def __getitem__(self, k):
110 return self.__dict__[k]
110 return self.__dict__[k]
111
111
112
112
113 def msg_header(msg_id, msg_type, username, session):
113 def msg_header(msg_id, msg_type, username, session):
114 return locals()
114 return locals()
115 # return {
115 # return {
116 # 'msg_id' : msg_id,
116 # 'msg_id' : msg_id,
117 # 'msg_type': msg_type,
117 # 'msg_type': msg_type,
118 # 'username' : username,
118 # 'username' : username,
119 # 'session' : session
119 # 'session' : session
120 # }
120 # }
121
121
122
122
123 def extract_header(msg_or_header):
123 def extract_header(msg_or_header):
124 """Given a message or header, return the header."""
124 """Given a message or header, return the header."""
125 if not msg_or_header:
125 if not msg_or_header:
126 return {}
126 return {}
127 try:
127 try:
128 # See if msg_or_header is the entire message.
128 # See if msg_or_header is the entire message.
129 h = msg_or_header['header']
129 h = msg_or_header['header']
130 except KeyError:
130 except KeyError:
131 try:
131 try:
132 # See if msg_or_header is just the header
132 # See if msg_or_header is just the header
133 h = msg_or_header['msg_id']
133 h = msg_or_header['msg_id']
134 except KeyError:
134 except KeyError:
135 raise
135 raise
136 else:
136 else:
137 h = msg_or_header
137 h = msg_or_header
138 if not isinstance(h, dict):
138 if not isinstance(h, dict):
139 h = dict(h)
139 h = dict(h)
140 return h
140 return h
141
141
142 def rekey(dikt):
142 def rekey(dikt):
143 """rekey a dict that has been forced to use str keys where there should be
143 """rekey a dict that has been forced to use str keys where there should be
144 ints by json. This belongs in the jsonutil added by fperez."""
144 ints by json. This belongs in the jsonutil added by fperez."""
145 for k in dikt.iterkeys():
145 for k in dikt.iterkeys():
146 if isinstance(k, str):
146 if isinstance(k, str):
147 ik=fk=None
147 ik=fk=None
148 try:
148 try:
149 ik = int(k)
149 ik = int(k)
150 except ValueError:
150 except ValueError:
151 try:
151 try:
152 fk = float(k)
152 fk = float(k)
153 except ValueError:
153 except ValueError:
154 continue
154 continue
155 if ik is not None:
155 if ik is not None:
156 nk = ik
156 nk = ik
157 else:
157 else:
158 nk = fk
158 nk = fk
159 if nk in dikt:
159 if nk in dikt:
160 raise KeyError("already have key %r"%nk)
160 raise KeyError("already have key %r"%nk)
161 dikt[nk] = dikt.pop(k)
161 dikt[nk] = dikt.pop(k)
162 return dikt
162 return dikt
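As a usage note, rekey simply undoes json's coercion of numeric keys to strings:

    d = {'0': 'engine-a', '1': 'engine-b', 'name': 'loads'}
    rekey(d)
    # -> {0: 'engine-a', 1: 'engine-b', 'name': 'loads'}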
163
163
164 def serialize_object(obj, threshold=64e-6):
164 def serialize_object(obj, threshold=64e-6):
165 """serialize an object into a list of sendable buffers.
165 """serialize an object into a list of sendable buffers.
166
166
167 Returns: (pmd, bufs)
167 Returns: (pmd, bufs)
168 where pmd is the pickled metadata wrapper, and bufs
168 where pmd is the pickled metadata wrapper, and bufs
169 is a list of data buffers"""
169 is a list of data buffers"""
170 # data larger than `threshold` is shipped as a separate buffer rather than embedded
170 # data larger than `threshold` is shipped as a separate buffer rather than embedded
171 databuffers = []
171 databuffers = []
172 if isinstance(obj, (list, tuple)):
172 if isinstance(obj, (list, tuple)):
173 clist = canSequence(obj)
173 clist = canSequence(obj)
174 slist = map(serialize, clist)
174 slist = map(serialize, clist)
175 for s in slist:
175 for s in slist:
176 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
176 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
177 databuffers.append(s.getData())
177 databuffers.append(s.getData())
178 s.data = None
178 s.data = None
179 return pickle.dumps(slist,-1), databuffers
179 return pickle.dumps(slist,-1), databuffers
180 elif isinstance(obj, dict):
180 elif isinstance(obj, dict):
181 sobj = {}
181 sobj = {}
182 for k in sorted(obj.iterkeys()):
182 for k in sorted(obj.iterkeys()):
183 s = serialize(can(obj[k]))
183 s = serialize(can(obj[k]))
184 if s.getDataSize() > threshold:
184 if s.getDataSize() > threshold:
185 databuffers.append(s.getData())
185 databuffers.append(s.getData())
186 s.data = None
186 s.data = None
187 sobj[k] = s
187 sobj[k] = s
188 return pickle.dumps(sobj,-1),databuffers
188 return pickle.dumps(sobj,-1),databuffers
189 else:
189 else:
190 s = serialize(can(obj))
190 s = serialize(can(obj))
191 if s.getDataSize() > threshold:
191 if s.getDataSize() > threshold:
192 databuffers.append(s.getData())
192 databuffers.append(s.getData())
193 s.data = None
193 s.data = None
194 return pickle.dumps(s,-1),databuffers
194 return pickle.dumps(s,-1),databuffers
195
195
196
196
197 def unserialize_object(bufs):
197 def unserialize_object(bufs):
198 """reconstruct an object serialized by serialize_object from data buffers"""
198 """reconstruct an object serialized by serialize_object from data buffers"""
199 bufs = list(bufs)
199 bufs = list(bufs)
200 sobj = pickle.loads(bufs.pop(0))
200 sobj = pickle.loads(bufs.pop(0))
201 if isinstance(sobj, (list, tuple)):
201 if isinstance(sobj, (list, tuple)):
202 for s in sobj:
202 for s in sobj:
203 if s.data is None:
203 if s.data is None:
204 s.data = bufs.pop(0)
204 s.data = bufs.pop(0)
205 return uncanSequence(map(unserialize, sobj))
205 return uncanSequence(map(unserialize, sobj))
206 elif isinstance(sobj, dict):
206 elif isinstance(sobj, dict):
207 newobj = {}
207 newobj = {}
208 for k in sorted(sobj.iterkeys()):
208 for k in sorted(sobj.iterkeys()):
209 s = sobj[k]
209 s = sobj[k]
210 if s.data is None:
210 if s.data is None:
211 s.data = bufs.pop(0)
211 s.data = bufs.pop(0)
212 newobj[k] = uncan(unserialize(s))
212 newobj[k] = uncan(unserialize(s))
213 return newobj
213 return newobj
214 else:
214 else:
215 if sobj.data is None:
215 if sobj.data is None:
216 sobj.data = bufs.pop(0)
216 sobj.data = bufs.pop(0)
217 return uncan(unserialize(sobj))
217 return uncan(unserialize(sobj))
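A quick roundtrip through the two functions above (a sketch; any picklable object should work):

    obj = {'a': range(10), 'b': 'hello'}
    pmd, bufs = serialize_object(obj)           # pickled metadata + list of raw data buffers
    restored = unserialize_object([pmd] + bufs)
    assert restored == obj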
218
218
219 def pack_apply_message(f, args, kwargs, threshold=64e-6):
219 def pack_apply_message(f, args, kwargs, threshold=64e-6):
220 """pack up a function, args, and kwargs to be sent over the wire
220 """pack up a function, args, and kwargs to be sent over the wire
221 as a series of buffers. Any object whose data is larger than `threshold`
221 as a series of buffers. Any object whose data is larger than `threshold`
222 will not have its data copied (currently only numpy arrays support zero-copy)"""
222 will not have its data copied (currently only numpy arrays support zero-copy)"""
223 msg = [pickle.dumps(can(f),-1)]
223 msg = [pickle.dumps(can(f),-1)]
224 databuffers = [] # for large objects
224 databuffers = [] # for large objects
225 sargs, bufs = serialize_object(args,threshold)
225 sargs, bufs = serialize_object(args,threshold)
226 msg.append(sargs)
226 msg.append(sargs)
227 databuffers.extend(bufs)
227 databuffers.extend(bufs)
228 skwargs, bufs = serialize_object(kwargs,threshold)
228 skwargs, bufs = serialize_object(kwargs,threshold)
229 msg.append(skwargs)
229 msg.append(skwargs)
230 databuffers.extend(bufs)
230 databuffers.extend(bufs)
231 msg.extend(databuffers)
231 msg.extend(databuffers)
232 return msg
232 return msg
233
233
234 def unpack_apply_message(bufs, g=None, copy=True):
234 def unpack_apply_message(bufs, g=None, copy=True):
235 """unpack f,args,kwargs from buffers packed by pack_apply_message()
235 """unpack f,args,kwargs from buffers packed by pack_apply_message()
236 Returns: original f,args,kwargs"""
236 Returns: original f,args,kwargs"""
237 bufs = list(bufs) # allow us to pop
237 bufs = list(bufs) # allow us to pop
238 assert len(bufs) >= 3, "not enough buffers!"
238 assert len(bufs) >= 3, "not enough buffers!"
239 if not copy:
239 if not copy:
240 for i in range(3):
240 for i in range(3):
241 bufs[i] = bufs[i].bytes
241 bufs[i] = bufs[i].bytes
242 cf = pickle.loads(bufs.pop(0))
242 cf = pickle.loads(bufs.pop(0))
243 sargs = list(pickle.loads(bufs.pop(0)))
243 sargs = list(pickle.loads(bufs.pop(0)))
244 skwargs = dict(pickle.loads(bufs.pop(0)))
244 skwargs = dict(pickle.loads(bufs.pop(0)))
245 # print sargs, skwargs
245 # print sargs, skwargs
246 f = uncan(cf, g)
246 f = uncan(cf, g)
247 for sa in sargs:
247 for sa in sargs:
248 if sa.data is None:
248 if sa.data is None:
249 m = bufs.pop(0)
249 m = bufs.pop(0)
250 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
250 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
251 if copy:
251 if copy:
252 sa.data = buffer(m)
252 sa.data = buffer(m)
253 else:
253 else:
254 sa.data = m.buffer
254 sa.data = m.buffer
255 else:
255 else:
256 if copy:
256 if copy:
257 sa.data = m
257 sa.data = m
258 else:
258 else:
259 sa.data = m.bytes
259 sa.data = m.bytes
260
260
261 args = uncanSequence(map(unserialize, sargs), g)
261 args = uncanSequence(map(unserialize, sargs), g)
262 kwargs = {}
262 kwargs = {}
263 for k in sorted(skwargs.iterkeys()):
263 for k in sorted(skwargs.iterkeys()):
264 sa = skwargs[k]
264 sa = skwargs[k]
265 if sa.data is None:
265 if sa.data is None:
266 sa.data = bufs.pop(0)
266 sa.data = bufs.pop(0)
267 kwargs[k] = uncan(unserialize(sa), g)
267 kwargs[k] = uncan(unserialize(sa), g)
268
268
269 return f,args,kwargs
269 return f,args,kwargs
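An illustrative roundtrip for the pair above; a builtin is used so the function pickles by reference and no code-object pickling support is needed:

    import operator

    bufs = pack_apply_message(operator.add, (40, 2), {})
    f, args, kwargs = unpack_apply_message(bufs)
    print f(*args, **kwargs)    # -> 42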
270
270
271 class StreamSession(object):
271 class StreamSession(object):
272 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
272 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
273 debug=False
273 debug=False
274 def __init__(self, username=None, session=None, packer=None, unpacker=None):
274 def __init__(self, username=None, session=None, packer=None, unpacker=None):
275 if username is None:
275 if username is None:
276 username = os.environ.get('USER','username')
276 username = os.environ.get('USER','username')
277 self.username = username
277 self.username = username
278 if session is None:
278 if session is None:
279 self.session = str(uuid.uuid4())
279 self.session = str(uuid.uuid4())
280 else:
280 else:
281 self.session = session
281 self.session = session
282 self.msg_id = str(uuid.uuid4())
282 self.msg_id = str(uuid.uuid4())
283 if packer is None:
283 if packer is None:
284 self.pack = default_packer
284 self.pack = default_packer
285 else:
285 else:
286 if not callable(packer):
286 if not callable(packer):
287 raise TypeError("packer must be callable, not %s"%type(packer))
287 raise TypeError("packer must be callable, not %s"%type(packer))
288 self.pack = packer
288 self.pack = packer
289
289
290 if unpacker is None:
290 if unpacker is None:
291 self.unpack = default_unpacker
291 self.unpack = default_unpacker
292 else:
292 else:
293 if not callable(unpacker):
293 if not callable(unpacker):
294 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
294 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
295 self.unpack = unpacker
295 self.unpack = unpacker
296
296
297 self.none = self.pack({})
297 self.none = self.pack({})
298
298
299 def msg_header(self, msg_type):
299 def msg_header(self, msg_type):
300 h = msg_header(self.msg_id, msg_type, self.username, self.session)
300 h = msg_header(self.msg_id, msg_type, self.username, self.session)
301 self.msg_id = str(uuid.uuid4())
301 self.msg_id = str(uuid.uuid4())
302 return h
302 return h
303
303
304 def msg(self, msg_type, content=None, parent=None, subheader=None):
304 def msg(self, msg_type, content=None, parent=None, subheader=None):
305 msg = {}
305 msg = {}
306 msg['header'] = self.msg_header(msg_type)
306 msg['header'] = self.msg_header(msg_type)
307 msg['msg_id'] = msg['header']['msg_id']
307 msg['msg_id'] = msg['header']['msg_id']
308 msg['parent_header'] = {} if parent is None else extract_header(parent)
308 msg['parent_header'] = {} if parent is None else extract_header(parent)
309 msg['msg_type'] = msg_type
309 msg['msg_type'] = msg_type
310 msg['content'] = {} if content is None else content
310 msg['content'] = {} if content is None else content
311 sub = {} if subheader is None else subheader
311 sub = {} if subheader is None else subheader
312 msg['header'].update(sub)
312 msg['header'].update(sub)
313 return msg
313 return msg
314
314
315 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
315 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
316 """Build and send a message via stream or socket.
316 """Build and send a message via stream or socket.
317
317
318 Parameters
318 Parameters
319 ----------
319 ----------
320
320
321 msg_type : str or Message/dict
321 msg_type : str or Message/dict
322 Normally, msg_type will be the str name of the message type to build; if a Message or dict is passed, it is treated as an already-built message and its content is reused.
322 Normally, msg_type will be the str name of the message type to build; if a Message or dict is passed, it is treated as an already-built message and its content is reused.
323
323
324
324
325
325
326 Returns
326 Returns
327 -------
327 -------
328 (msg,sent) : tuple
328 (msg,sent) : tuple
329 msg : Message
329 msg : Message
330 the nice wrapped dict-like object containing the headers
330 the nice wrapped dict-like object containing the headers
331
331
332 """
332 """
333 if isinstance(msg_type, (Message, dict)):
333 if isinstance(msg_type, (Message, dict)):
334 # we got a Message, not a msg_type
334 # we got a Message, not a msg_type
335 # don't build a new Message
335 # don't build a new Message
336 msg = msg_type
336 msg = msg_type
337 content = msg['content']
337 content = msg['content']
338 else:
338 else:
339 msg = self.msg(msg_type, content, parent, subheader)
339 msg = self.msg(msg_type, content, parent, subheader)
340 buffers = [] if buffers is None else buffers
340 buffers = [] if buffers is None else buffers
341 to_send = []
341 to_send = []
342 if isinstance(ident, list):
342 if isinstance(ident, list):
343 # accept list of idents
343 # accept list of idents
344 to_send.extend(ident)
344 to_send.extend(ident)
345 elif ident is not None:
345 elif ident is not None:
346 to_send.append(ident)
346 to_send.append(ident)
347 to_send.append(DELIM)
347 to_send.append(DELIM)
348 to_send.append(self.pack(msg['header']))
348 to_send.append(self.pack(msg['header']))
349 to_send.append(self.pack(msg['parent_header']))
349 to_send.append(self.pack(msg['parent_header']))
350 # if parent is None:
350 # if parent is None:
351 # to_send.append(self.none)
351 # to_send.append(self.none)
352 # else:
352 # else:
353 # to_send.append(self.pack(dict(parent)))
353 # to_send.append(self.pack(dict(parent)))
354 if content is None:
354 if content is None:
355 content = self.none
355 content = self.none
356 elif isinstance(content, dict):
356 elif isinstance(content, dict):
357 content = self.pack(content)
357 content = self.pack(content)
358 elif isinstance(content, str):
358 elif isinstance(content, str):
359 # content is already packed, as in a relayed message
359 # content is already packed, as in a relayed message
360 pass
360 pass
361 else:
361 else:
362 raise TypeError("Content incorrect type: %s"%type(content))
362 raise TypeError("Content incorrect type: %s"%type(content))
363 to_send.append(content)
363 to_send.append(content)
364 flag = 0
364 flag = 0
365 if buffers:
365 if buffers:
366 flag = zmq.SNDMORE
366 flag = zmq.SNDMORE
367 stream.send_multipart(to_send, flag, copy=False)
367 stream.send_multipart(to_send, flag, copy=False)
368 for b in buffers[:-1]:
368 for b in buffers[:-1]:
369 stream.send(b, flag, copy=False)
369 stream.send(b, flag, copy=False)
370 if buffers:
370 if buffers:
371 stream.send(buffers[-1], copy=False)
371 stream.send(buffers[-1], copy=False)
372 omsg = Message(msg)
372 omsg = Message(msg)
373 if self.debug:
373 if self.debug:
374 pprint.pprint(omsg)
374 pprint.pprint(omsg)
375 pprint.pprint(to_send)
375 pprint.pprint(to_send)
376 pprint.pprint(buffers)
376 pprint.pprint(buffers)
377 # return the wrapped message object that was sent
377 # return the wrapped message object that was sent
378 return omsg
378 return omsg
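To see the framing end to end, here is a minimal sketch that pushes one message through a PAIR socket pair and reads it back with recv below (the 'ping_request' msg_type is purely illustrative):

    import zmq

    ctx = zmq.Context()
    a = ctx.socket(zmq.PAIR)
    a.bind('inproc://example')
    b = ctx.socket(zmq.PAIR)
    b.connect('inproc://example')

    session = StreamSession(username='editor')
    session.send(a, 'ping_request', content=dict(when='now'))
    idents, msg = session.recv(b, mode=0)   # mode=0 blocks instead of the NOBLOCK default
    print msg['msg_type'], msg['content']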
379
379
380 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
380 def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
381 """send a raw message via idents.
381 """send a raw message via idents.
382
382
383 Parameters
383 Parameters
384 ----------
384 ----------
385 msg : list of sendable buffers"""
385 msg : list of sendable buffers"""
386 to_send = []
386 to_send = []
387 if isinstance(idents, str):
387 if isinstance(idents, str):
388 idents = [idents]
388 idents = [idents]
389 if idents is not None:
389 if idents is not None:
390 to_send.extend(idents)
390 to_send.extend(idents)
391 to_send.append(DELIM)
391 to_send.append(DELIM)
392 to_send.extend(msg)
392 to_send.extend(msg)
393 stream.send_multipart(to_send, flags, copy=copy)
393 stream.send_multipart(to_send, flags, copy=copy)
394
394
395 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
395 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
396 """receives and unpacks a message
396 """receives and unpacks a message
397 returns [idents], msg"""
397 returns [idents], msg"""
398 if isinstance(socket, ZMQStream):
398 if isinstance(socket, ZMQStream):
399 socket = socket.socket
399 socket = socket.socket
400 try:
400 try:
401 msg = socket.recv_multipart(mode)
401 msg = socket.recv_multipart(mode)
402 except zmq.ZMQError, e:
402 except zmq.ZMQError, e:
403 if e.errno == zmq.EAGAIN:
403 if e.errno == zmq.EAGAIN:
404 # We can convert EAGAIN to None as we know in this case
404 # We can convert EAGAIN to None as we know in this case
405 # recv_json won't return None.
405 # recv_json won't return None.
406 return None
406 return None
407 else:
407 else:
408 raise
408 raise
409 # return an actual Message object
409 # return an actual Message object
410 # determine the number of idents by trying to unpack them.
410 # determine the number of idents by trying to unpack them.
411 # this is terrible:
411 # this is terrible:
412 idents, msg = self.feed_identities(msg, copy)
412 idents, msg = self.feed_identities(msg, copy)
413 try:
413 try:
414 return idents, self.unpack_message(msg, content=content, copy=copy)
414 return idents, self.unpack_message(msg, content=content, copy=copy)
415 except Exception, e:
415 except Exception, e:
416 print idents, msg
416 print idents, msg
417 # TODO: handle it
417 # TODO: handle it
418 raise e
418 raise e
419
419
420 def feed_identities(self, msg, copy=True):
420 def feed_identities(self, msg, copy=True):
421 """This is a completely horrible thing, but it strips the zmq
421 """This is a completely horrible thing, but it strips the zmq
422 ident prefixes off of a message. It will break if any identities
422 ident prefixes off of a message. It will break if any identities
423 are unpackable by self.unpack."""
423 are unpackable by self.unpack."""
424 msg = list(msg)
424 msg = list(msg)
425 idents = []
425 idents = []
426 while len(msg) > 3:
426 while len(msg) > 3:
427 if copy:
427 if copy:
428 s = msg[0]
428 s = msg[0]
429 else:
429 else:
430 s = msg[0].bytes
430 s = msg[0].bytes
431 if s == DELIM:
431 if s == DELIM:
432 msg.pop(0)
432 msg.pop(0)
433 break
433 break
434 else:
434 else:
435 idents.append(s)
435 idents.append(s)
436 msg.pop(0)
436 msg.pop(0)
437
437
438 return idents, msg
438 return idents, msg
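Concretely, the wire format built by send is [ident, ..., DELIM, header, parent_header, content, buffers...], and feed_identities just splits off everything before the delimiter:

    parts = ['engine-0', 'client-7', DELIM, 'header', 'parent', 'content']
    idents, rest = StreamSession().feed_identities(parts)
    # idents == ['engine-0', 'client-7'];  rest == ['header', 'parent', 'content']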
439
439
440 def unpack_message(self, msg, content=True, copy=True):
440 def unpack_message(self, msg, content=True, copy=True):
441 """return a message object from the format
441 """Return a message object from the format
442 sent by self.send.
442 sent by self.send.
443
443
444 parameters:
444 Parameters:
445 -----------
445
446
446 content : bool (True)
447 content : bool (True)
447 whether to unpack the content dict (True),
448 whether to unpack the content dict (True),
448 or leave it serialized (False)
449 or leave it serialized (False)
449
450
450 copy : bool (True)
451 copy : bool (True)
451 whether to return the bytes (True),
452 whether to return the bytes (True),
452 or the non-copying Message object in each place (False)
453 or the non-copying Message object in each place (False)
453
454
454 """
455 """
455 if not len(msg) >= 3:
456 if not len(msg) >= 3:
456 raise TypeError("malformed message, must have at least 3 elements")
457 raise TypeError("malformed message, must have at least 3 elements")
457 message = {}
458 message = {}
458 if not copy:
459 if not copy:
459 for i in range(3):
460 for i in range(3):
460 msg[i] = msg[i].bytes
461 msg[i] = msg[i].bytes
461 message['header'] = self.unpack(msg[0])
462 message['header'] = self.unpack(msg[0])
462 message['msg_type'] = message['header']['msg_type']
463 message['msg_type'] = message['header']['msg_type']
463 message['parent_header'] = self.unpack(msg[1])
464 message['parent_header'] = self.unpack(msg[1])
464 if content:
465 if content:
465 message['content'] = self.unpack(msg[2])
466 message['content'] = self.unpack(msg[2])
466 else:
467 else:
467 message['content'] = msg[2]
468 message['content'] = msg[2]
468
469
469 # message['buffers'] = msg[3:]
470 # message['buffers'] = msg[3:]
470 # else:
471 # else:
471 # message['header'] = self.unpack(msg[0].bytes)
472 # message['header'] = self.unpack(msg[0].bytes)
472 # message['msg_type'] = message['header']['msg_type']
473 # message['msg_type'] = message['header']['msg_type']
473 # message['parent_header'] = self.unpack(msg[1].bytes)
474 # message['parent_header'] = self.unpack(msg[1].bytes)
474 # if content:
475 # if content:
475 # message['content'] = self.unpack(msg[2].bytes)
476 # message['content'] = self.unpack(msg[2].bytes)
476 # else:
477 # else:
477 # message['content'] = msg[2].bytes
478 # message['content'] = msg[2].bytes
478
479
479 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
480 message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ]
480 return message
481 return message
481
482
482
483
483
484
484 def test_msg2obj():
485 def test_msg2obj():
485 am = dict(x=1)
486 am = dict(x=1)
486 ao = Message(am)
487 ao = Message(am)
487 assert ao.x == am['x']
488 assert ao.x == am['x']
488
489
489 am['y'] = dict(z=1)
490 am['y'] = dict(z=1)
490 ao = Message(am)
491 ao = Message(am)
491 assert ao.y.z == am['y']['z']
492 assert ao.y.z == am['y']['z']
492
493
493 k1, k2 = 'y', 'z'
494 k1, k2 = 'y', 'z'
494 assert ao[k1][k2] == am[k1][k2]
495 assert ao[k1][k2] == am[k1][k2]
495
496
496 am2 = dict(ao)
497 am2 = dict(ao)
497 assert am['x'] == am2['x']
498 assert am['x'] == am2['x']
498 assert am['y']['z'] == am2['y']['z']
499 assert am['y']['z'] == am2['y']['z']