##// END OF EJS Templates
add log_errors decorator for on_recv callbacks...
MinRK -
Show More
@@ -1,180 +1,181 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18 import time
19 19 import uuid
20 20
21 21 import zmq
22 22 from zmq.devices import ThreadDevice
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.config.configurable import LoggingConfigurable
26 26 from IPython.utils.traitlets import Set, Instance, CFloat, Integer
27 27
28 from IPython.parallel.util import asbytes
28 from IPython.parallel.util import asbytes, log_errors
29 29
class Heart(object):
    """A basic heart object for responding to a HeartMonitor.
    This is a simple wrapper with defaults for the most common
    Device model for responding to heartbeats.

    It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
    SUB/XREQ for in/out.

    You can specify the XREQ's IDENTITY via the optional heart_id argument."""

    device = None
    id = None

    def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.DEALER, heart_id=None):
        # pick the identity first; a fresh random one unless the caller chose it
        if heart_id is None:
            heart_id = uuid.uuid4().bytes
        self.id = heart_id

        dev = ThreadDevice(zmq.FORWARDER, in_type, out_type)
        # do not allow the device to share global Context.instance,
        # which is the default behavior in pyzmq > 2.1.10
        dev.context_factory = zmq.Context
        dev.daemon = True
        dev.connect_in(in_addr)
        dev.connect_out(out_addr)
        if in_type == zmq.SUB:
            # subscribe to everything on the ping side
            dev.setsockopt_in(zmq.SUBSCRIBE, b"")
        dev.setsockopt_out(zmq.IDENTITY, heart_id)
        self.device = dev

    def start(self):
        """Launch the forwarder device in its background thread."""
        return self.device.start()
59 59
60 60
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class
    pingstream: a PUB stream
    pongstream: an XREP stream
    period: the period of the heartbeat in milliseconds

    On each period tick, the monitor broadcasts its current lifetime on the
    PUB socket; hearts echo it back on the XREP socket, and the monitor
    compares the sets of expected vs. responding hearts to detect new and
    failed engines.
    """

    period = Integer(3000, config=True,
        help='The frequency at which the Hub pings the engines for heartbeats '
        '(in ms)',
    )

    pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')
    def _loop_default(self):
        # default to the process-wide IOLoop
        return ioloop.IOLoop.instance()

    # not settable:
    hearts=Set()          # identities of all hearts currently considered alive
    responses=Set()       # identities that replied during the current period
    on_probation=Set()    # hearts that missed the previous beat; fail on a second miss
    last_ping=CFloat(0)   # lifetime value carried by the previous ping
    _new_handlers = Set()      # callbacks fired for each newly seen heart
    _failure_handlers = Set()  # callbacks fired for each failed heart
    lifetime = CFloat(0)  # seconds since start(); doubles as the ping payload
    tic = CFloat(0)       # wall-clock time of the last beat

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        # every pong that arrives goes through handle_pong
        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        """Begin periodic beats on the IOLoop."""
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        """One heartbeat period: tally last period's pongs, then send a new ping."""
        # deliver any pending pongs before evaluating responses
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc-self.tic
        self.tic = toc
        self.log.debug("heartbeat::sending %s", self.lifetime)
        # hearts that replied this period are good; known hearts that
        # didn't are candidates for failure; replies from unknown
        # identities are new hearts
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        # only hearts that were already on probation (second miss) fail now
        heartfailures = self.on_probation.intersection(missed_beats)
        newhearts = self.responses.difference(goodhearts)
        map(self.handle_new_heart, newhearts)
        map(self.handle_heart_failure, heartfailures)
        self.on_probation = missed_beats.intersection(self.hearts)
        self.responses = set()
        # print self.on_probation, self.hearts
        # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
        self.pingstream.send(asbytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def handle_new_heart(self, heart):
        """Register a newly seen heart, notifying any new-heart handlers."""
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        """Drop a failed heart, notifying any failure handlers.

        Handler exceptions are logged and suppressed so one bad handler
        cannot prevent the others from running.
        """
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception as e:
                    self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
                    pass
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)


    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        # msg[0] is the heart's identity, msg[1] the echoed lifetime payload
        current = asbytes(str(self.lifetime))
        last = asbytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time()-self.tic
            # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
            self.responses.add(msg[0])
        elif msg[1] == last:
            # reply to the previous ping: late but still counted this period
            delta = time.time()-self.tic + (self.lifetime-self.last_ping)
            self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
            self.responses.add(msg[0])
        else:
            self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
165 166
166 167
if __name__ == '__main__':
    # standalone smoke test: bind a local ping PUB / pong ROUTER pair and
    # run a HeartMonitor on them
    loop = ioloop.IOLoop.instance()
    context = zmq.Context()
    pub = context.socket(zmq.PUB)
    pub.bind('tcp://127.0.0.1:5555')
    xrep = context.socket(zmq.ROUTER)
    xrep.bind('tcp://127.0.0.1:5556')

    outstream = zmqstream.ZMQStream(pub, loop)
    instream = zmqstream.ZMQStream(xrep, loop)

    # HeartMonitor is a Configurable: traits must be passed by keyword.
    # The previous positional call HeartMonitor(loop, outstream, instream)
    # raised TypeError at runtime.
    hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
    # start the periodic beat; without this no pings are ever sent
    hb.start()

    loop.start()
@@ -1,1293 +1,1295 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.traitlets import (
32 32 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 33 )
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import RegistrationFactory
37 37
38 38 from IPython.zmq.session import SessionFactory
39 39
40 40 from .heartmonitor import HeartMonitor
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Code
44 44 #-----------------------------------------------------------------------------
45 45
46 46 def _passer(*args, **kwargs):
47 47 return
48 48
49 49 def _printer(*args, **kwargs):
50 50 print (args)
51 51 print (kwargs)
52 52
def empty_record():
    """Return an empty dict with all record keys.

    Every field starts as None, except stdout/stderr which accumulate
    text and therefore start as empty strings.
    """
    record = dict.fromkeys((
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    record['stdout'] = ''
    record['stderr'] = ''
    return record
76 76
def init_record(msg):
    """Initialize a TaskRecord dict based on a request message.

    Fields not known at submission time start as None (or '' for the
    stdout/stderr accumulators); the rest are filled from *msg*.
    """
    header = msg['header']
    record = dict.fromkeys((
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    record.update(
        msg_id=header['msg_id'],
        header=header,
        content=msg['content'],
        buffers=msg['buffers'],
        submitted=header['date'],
        stdout='',
        stderr='',
    )
    return record
101 101
102 102
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    queue (bytes): identity of queue's XREQ socket
    control (bytes): identity of control XREQ socket
    registration (bytes): identity of registration XREQ socket
    heartbeat (bytes): identity of heartbeat XREQ socket
    pending (set): pending msg_ids for this engine
    """
    id=Integer(0)
    queue=CBytes()
    control=CBytes()
    registration=CBytes()
    heartbeat=CBytes()
    pending=Set()
118 118
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds all the port/transport configuration and, in init_hub(), builds
    the zmq sockets, the HeartMonitor, the DB backend, and finally the Hub
    itself (stored as self.hub).
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Integer,Integer,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        return tuple(util.select_random_ports(2))

    mux = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Integer(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Integer(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # a single `ip` setting fans out to all three role-specific IPs
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # keep monitor_url consistent with the current transport/ip/port
        self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # a single `transport` setting fans out to all three transports
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        """Build the Hub and all of its connections."""
        self.init_hub()

    def start(self):
        """Start the heart monitor (the Hub itself is driven by the IOLoop)."""
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """Create the zmq streams, HeartMonitor, DB backend, and the Hub."""
        # interface templates with a trailing %i placeholder for the port
        client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
        if self.client_ip != self.engine_ip:
            # engines connect on a different interface; bind that too
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)

        ### Engine connections ###

        # heartbeat: pings go out on PUB, pongs come back on ROUTER
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.ROUTER)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket: broadcasts engine (un)registration to clients
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribes to all queue traffic
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # not configured; fall back to the scheduler's default
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)

        # resubmit stream: the Hub re-injects tasks via the task scheduler
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.bsession)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
298 298
299 299
300 300 class Hub(SessionFactory):
301 301 """The IPython Controller Hub with 0MQ connections
302 302
303 303 Parameters
304 304 ==========
305 305 loop: zmq IOLoop instance
306 306 session: Session object
307 307 <removed> context: zmq context for creating new connections (?)
308 308 queue: ZMQStream for monitoring the command queue (SUB)
309 309 query: ZMQStream for engine registration and client queries requests (XREP)
310 310 heartbeat: HeartMonitor object checking the pulse of the engines
311 311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 312 db: connection to db for out of memory logging of commands
313 313 NotImplemented
314 314 engine_info: dict of zmq connection information for engines to connect
315 315 to the queues.
316 316 client_info: dict of zmq connection information for engines to connect
317 317 to the queues.
318 318 """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict() # int engine id -> engine uuid (checked against dead_engines)
    by_ident=Dict() # engine socket identity -> int engine id
    engines=Dict() # int engine id -> EngineConnector
    clients=Dict() # client connections — not used in this section; verify against full file
    hearts=Dict() # heart identity -> int engine id
    pending=Set() # all currently pending msg_ids
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending task msg_ids keyed by engine_id (see save_task_result)
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids, across all engines
    dead_engines=Set() # uuids of engines whose hearts have failed
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict() # registrations awaiting their first heartbeat
    registration_timeout=Integer() # ms; set in __init__ from heartmonitor.period
    _idcounter=Integer(0) # monotonically increasing source of engine ids

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
346 346
347 347
    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least two heartbeat periods to finish registering
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # task entry is a (scheme, url) pair; only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # map monitor-topic frames to their handlers
        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
        }

        # map client msg_types to their handlers
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
        }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
409 409
410 410 @property
411 411 def _next_id(self):
412 412 """gemerate a new ID.
413 413
414 414 No longer reuse old ids, just count from 0."""
415 415 newid = self._idcounter
416 416 self._idcounter += 1
417 417 return newid
418 418 # newid = 0
419 419 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
420 420 # # print newid, self.ids, self.incoming_registrations
421 421 # while newid in self.ids or newid in incoming:
422 422 # newid += 1
423 423 # return newid
424 424
425 425 #-----------------------------------------------------------------------------
426 426 # message validation
427 427 #-----------------------------------------------------------------------------
428 428
429 429 def _validate_targets(self, targets):
430 430 """turn any valid targets argument into a list of integer ids"""
431 431 if targets is None:
432 432 # default to all
433 433 return self.ids
434 434
435 435 if isinstance(targets, (int,str,unicode)):
436 436 # only one target specified
437 437 targets = [targets]
438 438 _targets = []
439 439 for t in targets:
440 440 # map raw identities to ids
441 441 if isinstance(t, (str,unicode)):
442 442 t = self.by_ident.get(t, t)
443 443 _targets.append(t)
444 444 targets = _targets
445 445 bad_targets = [ t for t in targets if t not in self.ids ]
446 446 if bad_targets:
447 447 raise IndexError("No Such Engine: %r" % bad_targets)
448 448 if not targets:
449 449 raise IndexError("No Engines Registered")
450 450 return targets
451 451
452 452 #-----------------------------------------------------------------------------
453 453 # dispatch methods (1 per stream)
454 454 #-----------------------------------------------------------------------------
455 455
456 456
457 @util.log_errors
457 458 def dispatch_monitor_traffic(self, msg):
458 459 """all ME and Task queue messages come through here, as well as
459 460 IOPub traffic."""
460 461 self.log.debug("monitor traffic: %r", msg[0])
461 462 switch = msg[0]
462 463 try:
463 464 idents, msg = self.session.feed_identities(msg[1:])
464 465 except ValueError:
465 466 idents=[]
466 467 if not idents:
467 468 self.log.error("Bad Monitor Message: %r", msg)
468 469 return
469 470 handler = self.monitor_handlers.get(switch, None)
470 471 if handler is not None:
471 472 handler(idents, msg)
472 473 else:
473 474 self.log.error("Invalid monitor topic: %r", switch)
474 475
475 476
477 @util.log_errors
476 478 def dispatch_query(self, msg):
477 479 """Route registration requests and queries from clients."""
478 480 try:
479 481 idents, msg = self.session.feed_identities(msg)
480 482 except ValueError:
481 483 idents = []
482 484 if not idents:
483 485 self.log.error("Bad Query Message: %r", msg)
484 486 return
485 487 client_id = idents[0]
486 488 try:
487 489 msg = self.session.unserialize(msg, content=True)
488 490 except Exception:
489 491 content = error.wrap_exception()
490 492 self.log.error("Bad Query Message: %r", msg, exc_info=True)
491 493 self.session.send(self.query, "hub_error", ident=client_id,
492 494 content=content)
493 495 return
494 496 # print client_id, header, parent, content
495 497 #switch on message type:
496 498 msg_type = msg['header']['msg_type']
497 499 self.log.info("client::client %r requested %r", client_id, msg_type)
498 500 handler = self.query_handlers.get(msg_type, None)
499 501 try:
500 502 assert handler is not None, "Bad Message Type: %r" % msg_type
501 503 except:
502 504 content = error.wrap_exception()
503 505 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
504 506 self.session.send(self.query, "hub_error", ident=client_id,
505 507 content=content)
506 508 return
507 509
508 510 else:
509 511 handler(idents, msg)
510 512
511 513 def dispatch_db(self, msg):
512 514 """"""
513 515 raise NotImplementedError
514 516
515 517 #---------------------------------------------------------------------------
516 518 # handler methods (1 per event)
517 519 #---------------------------------------------------------------------------
518 520
519 521 #----------------------- Heartbeat --------------------------------------
520 522
521 523 def handle_new_heart(self, heart):
522 524 """handler to attach to heartbeater.
523 525 Called when a new heart starts to beat.
524 526 Triggers completion of registration."""
525 527 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
526 528 if heart not in self.incoming_registrations:
527 529 self.log.info("heartbeat::ignoring new heart: %r", heart)
528 530 else:
529 531 self.finish_registration(heart)
530 532
531 533
532 534 def handle_heart_failure(self, heart):
533 535 """handler to attach to heartbeater.
534 536 called when a previously registered heart fails to respond to beat request.
535 537 triggers unregistration"""
536 538 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
537 539 eid = self.hearts.get(heart, None)
538 540 queue = self.engines[eid].queue
539 541 if eid is None or self.keytable[eid] in self.dead_engines:
540 542 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
541 543 else:
542 544 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 545
544 546 #----------------------- MUX Queue Traffic ------------------------------
545 547
    def save_queue_request(self, idents, msg):
        """Record a request a client submitted to the MUX queue.

        idents carries (engine queue identity, client identity); the record
        is merged with any existing DB entry (iopub may have arrived first).
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # merge: warn on conflicting non-empty fields, fill in missing ones
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no existing record: create a fresh one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
592 594
    def save_queue_result(self, idents, msg):
        """Record the result of a MUX-queue request completing on an engine."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # without a parent header the reply can't be matched to a request
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            # move the msg_id from pending to completed bookkeeping
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
642 644
643 645
644 646 #--------------------- Task Queue Traffic ------------------------------
645 647
    def save_task_request(self, idents, msg):
        """Save the submission of a task.

        The task has no engine yet, so it is added to `unassigned` as well
        as `pending`; the record is merged with any existing DB entry
        (iopub may have arrived first, or this may be a resubmission).
        """
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
            # still check content,header which should not change
            # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no existing record: create a fresh one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
696 698
697 699 def save_task_result(self, idents, msg):
698 700 """save the result of a completed task."""
699 701 client_id = idents[0]
700 702 try:
701 703 msg = self.session.unserialize(msg)
702 704 except Exception:
703 705 self.log.error("task::invalid task result message send to %r: %r",
704 706 client_id, msg, exc_info=True)
705 707 return
706 708
707 709 parent = msg['parent_header']
708 710 if not parent:
709 711 # print msg
710 712 self.log.warn("Task %r had no parent!", msg)
711 713 return
712 714 msg_id = parent['msg_id']
713 715 if msg_id in self.unassigned:
714 716 self.unassigned.remove(msg_id)
715 717
716 718 header = msg['header']
717 719 engine_uuid = header.get('engine', None)
718 720 eid = self.by_ident.get(engine_uuid, None)
719 721
720 722 if msg_id in self.pending:
721 723 self.log.info("task::task %r finished on %s", msg_id, eid)
722 724 self.pending.remove(msg_id)
723 725 self.all_completed.add(msg_id)
724 726 if eid is not None:
725 727 self.completed[eid].append(msg_id)
726 728 if msg_id in self.tasks[eid]:
727 729 self.tasks[eid].remove(msg_id)
728 730 completed = header['date']
729 731 started = header.get('started', None)
730 732 result = {
731 733 'result_header' : header,
732 734 'result_content': msg['content'],
733 735 'started' : started,
734 736 'completed' : completed,
735 737 'engine_uuid': engine_uuid
736 738 }
737 739
738 740 result['result_buffers'] = msg['buffers']
739 741 try:
740 742 self.db.update_record(msg_id, result)
741 743 except Exception:
742 744 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
743 745
744 746 else:
745 747 self.log.debug("task::unknown task %r finished", msg_id)
746 748
    def save_task_destination(self, idents, msg):
        """Record which engine a task was assigned to."""
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[util.asbytes(engine_uuid)]

        self.log.info("task::task %r arrived on %r", msg_id, eid)
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
771 773
772 774
773 775 def mia_task_request(self, idents, msg):
774 776 raise NotImplementedError
775 777 client_id = idents[0]
776 778 # content = dict(mia=self.mia,status='ok')
777 779 # self.session.send('mia_reply', content=content, idents=client_id)
778 780
779 781
780 782 #--------------------- IOPub Traffic ------------------------------
781 783
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %r", msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db; iopub may arrive before the request itself
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)

        # build the partial update for this message type
        d = {}
        if msg_type == 'stream':
            # append to any previously captured stream output
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
824 826
825 827
826 828
827 829 #-------------------------------------------------------------------------
828 830 # Registration requests
829 831 #-------------------------------------------------------------------------
830 832
831 833 def connection_request(self, client_id, msg):
832 834 """Reply with connection addresses for clients."""
833 835 self.log.info("client::client %r connected", client_id)
834 836 content = dict(status='ok')
835 837 content.update(self.client_info)
836 838 jsonable = {}
837 839 for k,v in self.keytable.iteritems():
838 840 if v not in self.dead_engines:
839 841 jsonable[str(k)] = v.decode('ascii')
840 842 content['engines'] = jsonable
841 843 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
842 844
    def register_engine(self, reg, msg):
        """Register a new engine.

        Validates that the engine's queue identity and heartbeat id are not
        already in use (registered or pending), replies with a registration
        status, and either finishes registration immediately (heart already
        beating) or parks it in incoming_registrations with a purge timeout.
        """
        content = msg['content']
        try:
            queue = util.asbytes(content['queue'])
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        if heart:
            heart = util.asbytes(heart)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            try:
                # raise-then-wrap is deliberate: wrap_exception captures a traceback
                raise KeyError("queue_id %r in use" % queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use", queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use" % heart)
            except:
                self.log.error("heart_id %r in use", heart, exc_info=True)
                content = error.wrap_exception()
        else:
            # also check registrations that are still in flight
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use" % heart)
                    except:
                        self.log.error("heart_id %r in use", heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use" % queue)
                    except:
                        self.log.error("queue_id %r in use", queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                # wait for the heart to start beating before finishing;
                # purge the stalled registration if it never does
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
        return eid
909 911
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)", eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid.decode('ascii'))
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        # defer cleanup of messages stranded on this engine, in case results
        # are still in flight
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
936 938
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            # mark each stranded message as completed-with-error
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                # raise-then-wrap captures a traceback for the client
                raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
965 967
966 968
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r", eid, queue)
        if purge is not None:
            # cancel the stalled-registration purge callback
            purge.stop()
        control = queue
        # record the engine in all lookup tables
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i", eid)
992 994
993 995 def _purge_stalled_registration(self, heart):
994 996 if heart in self.incoming_registrations:
995 997 eid = self.incoming_registrations.pop(heart)[0]
996 998 self.log.info("registration::purging stalled registration: %i", eid)
997 999 else:
998 1000 pass
999 1001
1000 1002 #-------------------------------------------------------------------------
1001 1003 # Client Requests
1002 1004 #-------------------------------------------------------------------------
1003 1005
    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        # acknowledge the requester first
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        # delay actual shutdown 1s so the replies above can be delivered
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()
1011 1013
    def _shutdown(self):
        """Terminate the hub process (called via DelayedCallback)."""
        self.log.info("hub::hub shutting down.")
        # brief pause to let the log message flush
        time.sleep(0.1)
        sys.exit(0)
1016 1018
1017 1019
    def check_load(self, client_id, msg):
        """Reply with the queue+task load of each requested engine."""
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            # load = pending MUX jobs + pending task jobs
            content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)
1034 1036
1035 1037
    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.

        if verbose: return the msg_ids
        else: return len of each type.

        keys:
            queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)
        """
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                # summary mode: counts instead of msg_id lists
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
        # print (content)
        self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1066 1068
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            # drop every completed record
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            # refuse to purge messages that are still pending
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            # additionally purge everything completed on the given engines
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1109 1111
1110 1112 def resubmit_task(self, client_id, msg):
1111 1113 """Resubmit one or more tasks."""
1112 1114 def finish(reply):
1113 1115 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1114 1116
1115 1117 content = msg['content']
1116 1118 msg_ids = content['msg_ids']
1117 1119 reply = dict(status='ok')
1118 1120 try:
1119 1121 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1120 1122 'header', 'content', 'buffers'])
1121 1123 except Exception:
1122 1124 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1123 1125 return finish(error.wrap_exception())
1124 1126
1125 1127 # validate msg_ids
1126 1128 found_ids = [ rec['msg_id'] for rec in records ]
1127 1129 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1128 1130 if len(records) > len(msg_ids):
1129 1131 try:
1130 1132 raise RuntimeError("DB appears to be in an inconsistent state."
1131 1133 "More matching records were found than should exist")
1132 1134 except Exception:
1133 1135 return finish(error.wrap_exception())
1134 1136 elif len(records) < len(msg_ids):
1135 1137 missing = [ m for m in msg_ids if m not in found_ids ]
1136 1138 try:
1137 1139 raise KeyError("No such msg(s): %r" % missing)
1138 1140 except KeyError:
1139 1141 return finish(error.wrap_exception())
1140 1142 elif invalid_ids:
1141 1143 msg_id = invalid_ids[0]
1142 1144 try:
1143 1145 raise ValueError("Task %r appears to be inflight" % msg_id)
1144 1146 except Exception:
1145 1147 return finish(error.wrap_exception())
1146 1148
1147 1149 # clear the existing records
1148 1150 now = datetime.now()
1149 1151 rec = empty_record()
1150 1152 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1151 1153 rec['resubmitted'] = now
1152 1154 rec['queue'] = 'task'
1153 1155 rec['client_uuid'] = client_id[0]
1154 1156 try:
1155 1157 for msg_id in msg_ids:
1156 1158 self.all_completed.discard(msg_id)
1157 1159 self.db.update_record(msg_id, rec)
1158 1160 except Exception:
1159 1161 self.log.error('db::db error upating record', exc_info=True)
1160 1162 reply = error.wrap_exception()
1161 1163 else:
1162 1164 # send the messages
1163 1165 for rec in records:
1164 1166 header = rec['header']
1165 1167 # include resubmitted in header to prevent digest collision
1166 1168 header['resubmitted'] = now
1167 1169 msg = self.session.msg(header['msg_type'])
1168 1170 msg['content'] = rec['content']
1169 1171 msg['header'] = header
1170 1172 msg['header']['msg_id'] = rec['msg_id']
1171 1173 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1172 1174
1173 1175 finish(dict(status='ok'))
1174 1176
1175 1177
1176 1178 def _extract_record(self, rec):
1177 1179 """decompose a TaskRecord dict into subsection of reply for get_result"""
1178 1180 io_dict = {}
1179 1181 for key in 'pyin pyout pyerr stdout stderr'.split():
1180 1182 io_dict[key] = rec[key]
1181 1183 content = { 'result_content': rec['result_content'],
1182 1184 'header': rec['header'],
1183 1185 'result_header' : rec['result_header'],
1184 1186 'io' : io_dict,
1185 1187 }
1186 1188 if rec['result_buffers']:
1187 1189 buffers = map(bytes, rec['result_buffers'])
1188 1190 else:
1189 1191 buffers = []
1190 1192
1191 1193 return content, buffers
1192 1194
1193 1195 def get_results(self, client_id, msg):
1194 1196 """Get the result of 1 or more messages."""
1195 1197 content = msg['content']
1196 1198 msg_ids = sorted(set(content['msg_ids']))
1197 1199 statusonly = content.get('status_only', False)
1198 1200 pending = []
1199 1201 completed = []
1200 1202 content = dict(status='ok')
1201 1203 content['pending'] = pending
1202 1204 content['completed'] = completed
1203 1205 buffers = []
1204 1206 if not statusonly:
1205 1207 try:
1206 1208 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1207 1209 # turn match list into dict, for faster lookup
1208 1210 records = {}
1209 1211 for rec in matches:
1210 1212 records[rec['msg_id']] = rec
1211 1213 except Exception:
1212 1214 content = error.wrap_exception()
1213 1215 self.session.send(self.query, "result_reply", content=content,
1214 1216 parent=msg, ident=client_id)
1215 1217 return
1216 1218 else:
1217 1219 records = {}
1218 1220 for msg_id in msg_ids:
1219 1221 if msg_id in self.pending:
1220 1222 pending.append(msg_id)
1221 1223 elif msg_id in self.all_completed:
1222 1224 completed.append(msg_id)
1223 1225 if not statusonly:
1224 1226 c,bufs = self._extract_record(records[msg_id])
1225 1227 content[msg_id] = c
1226 1228 buffers.extend(bufs)
1227 1229 elif msg_id in records:
1228 1230 if rec['completed']:
1229 1231 completed.append(msg_id)
1230 1232 c,bufs = self._extract_record(records[msg_id])
1231 1233 content[msg_id] = c
1232 1234 buffers.extend(bufs)
1233 1235 else:
1234 1236 pending.append(msg_id)
1235 1237 else:
1236 1238 try:
1237 1239 raise KeyError('No such message: '+msg_id)
1238 1240 except:
1239 1241 content = error.wrap_exception()
1240 1242 break
1241 1243 self.session.send(self.query, "result_reply", content=content,
1242 1244 parent=msg, ident=client_id,
1243 1245 buffers=buffers)
1244 1246
1245 1247 def get_history(self, client_id, msg):
1246 1248 """Get a list of all msg_ids in our DB records"""
1247 1249 try:
1248 1250 msg_ids = self.db.get_history()
1249 1251 except Exception as e:
1250 1252 content = error.wrap_exception()
1251 1253 else:
1252 1254 content = dict(status='ok', history=msg_ids)
1253 1255
1254 1256 self.session.send(self.query, "history_reply", content=content,
1255 1257 parent=msg, ident=client_id)
1256 1258
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database."""
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            # buffer_lens track how many buffers belong to each record,
            # so the client can re-associate the flat `buffers` list
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = None
                result_buffer_lens = None

            for rec in records:
                # buffers may be None, so double check
                b = rec.pop('buffers', empty) or empty
                if buffer_lens is not None:
                    buffer_lens.append(len(b))
                    buffers.extend(b)
                rb = rec.pop('result_buffers', empty) or empty
                if result_buffer_lens is not None:
                    result_buffer_lens.append(len(rb))
                    buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                    result_buffer_lens=result_buffer_lens)
        # self.log.debug (content)
        self.session.send(self.query, "db_reply", content=content,
                                            parent=msg, ident=client_id,
                                            buffers=buffers)
1293 1295
@@ -1,759 +1,767 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26 import time
27 27
28 28 from datetime import datetime, timedelta
29 29 from random import randint, random
30 30 from types import FunctionType
31 31
32 32 try:
33 33 import numpy
34 34 except ImportError:
35 35 numpy = None
36 36
37 37 import zmq
38 38 from zmq.eventloop import ioloop, zmqstream
39 39
40 40 # local imports
41 41 from IPython.external.decorator import decorator
42 42 from IPython.config.application import Application
43 43 from IPython.config.loader import Config
44 44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
45 45
46 from IPython.parallel import error
46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger, asbytes
49 49
50 50 from .dependency import Dependency
51 51
@decorator
def logged(f,self,*args,**kwargs):
    """Decorator: debug-log every call to the wrapped scheduler method."""
    # print ("#--------------------")
    self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
    # print ("#--")
    return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
def plainrandom(loads):
    """Plain random pick."""
    return randint(0, len(loads) - 1)
67 67
def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    # index 0 is by construction the least-recently-used engine
    return 0
76 76
def twobin(loads):
    """Pick two at random, use the LRU of the two.

    The content of loads is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    last = len(loads) - 1
    first_pick = randint(0, last)
    second_pick = randint(0, last)
    # lower index == older == LRU
    return min(first_pick, second_pick)
88 88
def weighted(loads):
    """Pick two at random using inverse load as weight.

    Return the less loaded of the two.
    """
    # weight 0 a million times more than 1:
    weights = 1./(1e-6+numpy.array(loads))
    totals = weights.cumsum()
    top = totals[-1]

    def draw():
        # sample an index with probability proportional to its weight
        target = random()*top
        i = 0
        while totals[i] < target:
            i += 1
        return i

    idx = draw()
    idy = draw()
    # keep whichever of the two draws has the higher weight (lower load)
    return idy if weights[idy] > weights[idx] else idx
110 110
def leastload(loads):
    """Always choose the lowest load.

    If the lowest load occurs more than once, the first
    occurance will be used. If loads has LRU ordering, this means
    the LRU of those with the lowest load is chosen.
    """
    lowest = min(loads)
    return loads.index(lowest)
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
class Job(object):
    """Simple container for a job"""

    def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
        # identity and wire-format payload
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        # scheduling constraints
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout
        # bookkeeping: submission time and engines this job must avoid
        self.timestamp = time.time()
        self.blacklist = set()

    @property
    def dependents(self):
        """Union of the follow and after dependency sets."""
        return self.after.union(self.follow)
149 149
150 150 class TaskScheduler(SessionFactory):
151 151 """Python TaskScheduler object.
152 152
153 153 This is the simplest object that supports msg_id based
154 154 DAG dependencies. *Only* task msg_ids are checked, not
155 155 msg_ids of jobs submitted via the MUX queue.
156 156
157 157 """
158 158
159 159 hwm = Integer(1, config=True,
160 160 help="""specify the High Water Mark (HWM) for the downstream
161 161 socket in the Task scheduler. This is the maximum number
162 162 of allowed outstanding tasks on each engine.
163 163
164 164 The default (1) means that only one task can be outstanding on each
165 165 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
166 166 engines continue to be assigned tasks while they are working,
167 167 effectively hiding network latency behind computation, but can result
168 168 in an imbalance of work when submitting many heterogenous tasks all at
169 169 once. Any positive value greater than one is a compromise between the
170 170 two.
171 171
172 172 """
173 173 )
174 174 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
175 175 'leastload', config=True, allow_none=False,
176 176 help="""select the task scheduler scheme [default: Python LRU]
177 177 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
178 178 )
    def _scheme_name_changed(self, old, new):
        # traitlets change handler: swap in the module-level chooser
        # function whose name matches the newly selected scheme
        self.log.debug("Using scheme %r"%new)
        self.scheme = globals()[new]
182 182
183 183 # input arguments:
184 184 scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        # default chooser function matches scheme_name's default 'leastload'
        return leastload
187 187 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
188 188 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
189 189 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
190 190 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
191 191
192 192 # internals:
193 193 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
194 194 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
195 195 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
196 196 depending = Dict() # dict by msg_id of Jobs
197 197 pending = Dict() # dict by engine_uuid of submitted tasks
198 198 completed = Dict() # dict by engine_uuid of completed tasks
199 199 failed = Dict() # dict by engine_uuid of failed tasks
200 200 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
201 201 clients = Dict() # dict by msg_id for who submitted the task
202 202 targets = List() # list of target IDENTs
203 203 loads = List() # list of engine loads
204 204 # full = Set() # set of IDENTs that have HWM outstanding tasks
205 205 all_completed = Set() # set of all completed tasks
206 206 all_failed = Set() # set of all failed tasks
207 207 all_done = Set() # set of all finished tasks=union(completed,failed)
208 208 all_ids = Set() # set of all submitted task IDs
209 209
210 210 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
211 211
212 212 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 213 # but ensure Bytes
    def _ident_default(self):
        # ZMQ identity defaults to the session's bytes id
        return self.session.bsession
216 216
    def start(self):
        """Begin scheduling: hook stream callbacks and start the timeout auditor."""
        self.engine_stream.on_recv(self.dispatch_result, copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)

        self._notification_handlers = dict(
            registration_notification = self._register_engine,
            unregistration_notification = self._unregister_engine
        )
        self.notifier_stream.on_recv(self.dispatch_notification)
        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
        self.auditor.start()
        self.log.info("Scheduler started [%s]"%self.scheme_name)
229 229
    def resume_receiving(self):
        """Resume accepting jobs."""
        # re-register the submission callback removed by stop_receiving()
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
233 233
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        # unregistering the callback leaves incoming requests queued in zmq
        self.client_stream.on_recv(None)
238 238
239 239 #-----------------------------------------------------------------------
240 240 # [Un]Registration Handling
241 241 #-----------------------------------------------------------------------
242 242
243
244 @util.log_errors
243 245 def dispatch_notification(self, msg):
244 246 """dispatch register/unregister events."""
245 247 try:
246 248 idents,msg = self.session.feed_identities(msg)
247 249 except ValueError:
248 250 self.log.warn("task::Invalid Message: %r",msg)
249 251 return
250 252 try:
251 253 msg = self.session.unserialize(msg)
252 254 except ValueError:
253 255 self.log.warn("task::Unauthorized message from: %r"%idents)
254 256 return
255 257
256 258 msg_type = msg['header']['msg_type']
257 259
258 260 handler = self._notification_handlers.get(msg_type, None)
259 261 if handler is None:
260 262 self.log.error("Unhandled message type: %r"%msg_type)
261 263 else:
262 264 try:
263 265 handler(asbytes(msg['content']['queue']))
264 266 except Exception:
265 267 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
266 268
    def _register_engine(self, uid):
        """New engine with ident `uid` became available."""
        # head of the line:
        self.targets.insert(0,uid)
        self.loads.insert(0,0)

        # initialize sets
        self.completed[uid] = set()
        self.failed[uid] = set()
        self.pending[uid] = {}

        # rescan the graph: previously-parked jobs may now be runnable
        self.update_graph(None)
280 282
    def _unregister_engine(self, uid):
        """Existing engine with ident `uid` became unavailable."""
        # NOTE(review): this branch is currently a no-op placeholder
        if len(self.targets) == 1:
            # this was our only engine
            pass

        # handle any potentially finished tasks:
        self.engine_stream.flush()

        # don't pop destinations, because they might be used later
        # map(self.destinations.pop, self.completed.pop(uid))
        # map(self.destinations.pop, self.failed.pop(uid))

        # prevent this engine from receiving work
        idx = self.targets.index(uid)
        self.targets.pop(idx)
        self.loads.pop(idx)

        # wait 5 seconds before cleaning up pending jobs, since the results might
        # still be incoming
        if self.pending[uid]:
            dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
            dc.start()
        else:
            self.completed.pop(uid)
            self.failed.pop(uid)
307 309
308 310
    def handle_stranded_tasks(self, engine):
        """Deal with jobs resident in an engine that died.

        Builds a fake 'apply_reply' carrying an EngineError for every task
        still pending on `engine`, and dispatches it through the normal
        result path so clients and bookkeeping see a failure.
        """
        lost = self.pending[engine]
        for msg_id in lost.keys():
            if msg_id not in self.pending[engine]:
                # prevent double-handling of messages
                continue

            # NOTE(review): pending maps msg_id -> Job (see submit_task);
            # `lost[msg_id][0]` indexing a Job looks stale — expected
            # `.raw_msg`? confirm Job supports __getitem__.
            raw_msg = lost[msg_id][0]
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            parent = self.session.unpack(msg[1].bytes)
            # route the fake reply as if it came from the dead engine
            idents = [engine, idents[0]]

            # build fake error reply
            # raise/catch so wrap_exception() captures a real traceback
            try:
                raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
            except:
                content = error.wrap_exception()
            # build fake header
            header = dict(
                status='error',
                engine=engine,
                date=datetime.now(),
            )
            msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
            raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
            # and dispatch it
            self.dispatch_result(raw_reply)

        # finally scrub completed/failed lists
        self.completed.pop(engine)
        self.failed.pop(engine)
341 343
342 344
343 345 #-----------------------------------------------------------------------
344 346 # Job Submission
345 347 #-----------------------------------------------------------------------
348
349
    @util.log_errors
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Validates the message, records the task, resolves its time ('after')
        and location ('follow') dependencies, then either submits it to an
        engine or parks it until its dependencies are met.
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return


        # send to monitor
        self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # get targets as a set of bytes objects
        # from a list of unicode objects
        targets = header.get('targets', [])
        targets = map(asbytes, targets)
        targets = set(targets)

        retries = header.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = header.get('after', None)
        if after:
            after = Dependency(after)
            if after.all:
                # strip out deps that are already satisfied, keeping flags
                if after.success:
                    after = Dependency(after.difference(self.all_completed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
                if after.failure:
                    after = Dependency(after.difference(self.all_failed),
                                success=after.success,
                                failure=after.failure,
                                all=after.all,
                    )
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET
        else:
            after = MET

        # location dependencies
        follow = Dependency(header.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = header.get('timeout', None)
        if timeout:
            # cast to float, because jsonlib returns floats as decimal.Decimal,
            # which timedelta does not accept
            timeout = datetime.now() + timedelta(0,float(timeout),0)

        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                header=header, targets=targets, after=after, follow=follow,
                timeout=timeout,
        )

        # validate and reduce dependencies:
        for dep in after,follow:
            if not dep: # empty dependency
                continue
            # check valid: depending on yourself or on an unknown msg_id
            # can never be satisfied
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = job
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(job):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(job)
        else:
            self.save_unmet(job)
435 440
436 441 def audit_timeouts(self):
437 442 """Audit all waiting tasks for expired timeouts."""
438 443 now = datetime.now()
439 444 for msg_id in self.depending.keys():
440 445 # must recheck, in case one failure cascaded to another:
441 446 if msg_id in self.depending:
442 447 job = self.depending[msg_id]
443 448 if job.timeout and job.timeout < now:
444 449 self.fail_unreachable(msg_id, error.TaskTimeout)
445 450
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!", msg_id)
            return
        job = self.depending.pop(msg_id)
        # remove this task from its dependents' reverse-dependency sets
        for mid in job.dependents:
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # raise/catch so wrap_exception() captures a real traceback
        try:
            raise why()
        except:
            content = error.wrap_exception()

        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        # reply to the client, and mirror the reply to the Hub monitor
        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                                        parent=job.header, ident=job.idents)
        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)

        # this failure may cascade to dependents
        self.update_graph(msg_id, success=False)
470 475
471 476 def maybe_run(self, job):
472 477 """check location dependencies, and run if they are met."""
473 478 msg_id = job.msg_id
474 479 self.log.debug("Attempting to assign task %s", msg_id)
475 480 if not self.targets:
476 481 # no engines, definitely can't run
477 482 return False
478 483
479 484 if job.follow or job.targets or job.blacklist or self.hwm:
480 485 # we need a can_run filter
481 486 def can_run(idx):
482 487 # check hwm
483 488 if self.hwm and self.loads[idx] == self.hwm:
484 489 return False
485 490 target = self.targets[idx]
486 491 # check blacklist
487 492 if target in job.blacklist:
488 493 return False
489 494 # check targets
490 495 if job.targets and target not in job.targets:
491 496 return False
492 497 # check follow
493 498 return job.follow.check(self.completed[target], self.failed[target])
494 499
495 500 indices = filter(can_run, range(len(self.targets)))
496 501
497 502 if not indices:
498 503 # couldn't run
499 504 if job.follow.all:
500 505 # check follow for impossibility
501 506 dests = set()
502 507 relevant = set()
503 508 if job.follow.success:
504 509 relevant = self.all_completed
505 510 if job.follow.failure:
506 511 relevant = relevant.union(self.all_failed)
507 512 for m in job.follow.intersection(relevant):
508 513 dests.add(self.destinations[m])
509 514 if len(dests) > 1:
510 515 self.depending[msg_id] = job
511 516 self.fail_unreachable(msg_id)
512 517 return False
513 518 if job.targets:
514 519 # check blacklist+targets for impossibility
515 520 job.targets.difference_update(blacklist)
516 521 if not job.targets or not job.targets.intersection(self.targets):
517 522 self.depending[msg_id] = job
518 523 self.fail_unreachable(msg_id)
519 524 return False
520 525 return False
521 526 else:
522 527 indices = None
523 528
524 529 self.submit_task(job, indices)
525 530 return True
526 531
527 532 def save_unmet(self, job):
528 533 """Save a message for later submission when its dependencies are met."""
529 534 msg_id = job.msg_id
530 535 self.depending[msg_id] = job
531 536 # track the ids in follow or after, but not those already finished
532 537 for dep_id in job.after.union(job.follow).difference(self.all_done):
533 538 if dep_id not in self.graph:
534 539 self.graph[dep_id] = set()
535 540 self.graph[dep_id].add(msg_id)
536 541
537 542 def submit_task(self, job, indices=None):
538 543 """Submit a task to any of a subset of our targets."""
539 544 if indices:
540 545 loads = [self.loads[i] for i in indices]
541 546 else:
542 547 loads = self.loads
543 548 idx = self.scheme(loads)
544 549 if indices:
545 550 idx = indices[idx]
546 551 target = self.targets[idx]
547 552 # print (target, map(str, msg[:3]))
548 553 # send job to the engine
549 554 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
550 555 self.engine_stream.send_multipart(job.raw_msg, copy=False)
551 556 # update load
552 557 self.add_job(idx)
553 558 self.pending[target][job.msg_id] = job
554 559 # notify Hub
555 560 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
556 561 self.session.send(self.mon_stream, 'task_destination', content=content,
557 562 ident=[b'tracktask',self.ident])
558 563
559 564
560 565 #-----------------------------------------------------------------------
561 566 # Result Handling
562 567 #-----------------------------------------------------------------------
568
569
    @util.log_errors
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies

        Decrements the source engine's load, then either relays the result
        to the client or re-dispatches the task when its dependencies were
        unmet on the engine (possibly consuming a retry).
        """
        try:
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unserialize(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
            return

        header = msg['header']
        parent = msg['parent_header']
        if header.get('dependencies_met', True):
            success = (header['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed, but retries remain: treat like an unmet dependency
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)
597 605
598 606 def handle_result(self, idents, parent, raw_msg, success=True):
599 607 """handle a real task result, either success or failure"""
600 608 # first, relay result to client
601 609 engine = idents[0]
602 610 client = idents[1]
603 611 # swap_ids for XREP-XREP mirror
604 612 raw_msg[:2] = [client,engine]
605 613 # print (map(str, raw_msg[:4]))
606 614 self.client_stream.send_multipart(raw_msg, copy=False)
607 615 # now, update our data structures
608 616 msg_id = parent['msg_id']
609 617 self.pending[engine].pop(msg_id)
610 618 if success:
611 619 self.completed[engine].add(msg_id)
612 620 self.all_completed.add(msg_id)
613 621 else:
614 622 self.failed[engine].add(msg_id)
615 623 self.all_failed.add(msg_id)
616 624 self.all_done.add(msg_id)
617 625 self.destinations[msg_id] = engine
618 626
619 627 self.update_graph(msg_id, success)
620 628
    def handle_unmet_dependency(self, idents, parent):
        """handle an unmet dependency

        The engine refused (or failed) the task: blacklist it there and try
        to place the task elsewhere, failing it if no engine can ever run it.
        """
        engine = idents[0]
        msg_id = parent['msg_id']

        job = self.pending[engine].pop(msg_id)
        job.blacklist.add(engine)

        if job.blacklist == job.targets:
            # every allowed target has rejected it: unreachable
            self.depending[msg_id] = job
            self.fail_unreachable(msg_id)
        elif not self.maybe_run(job):
            # resubmit failed
            if msg_id not in self.all_failed:
                # put it back in our dependency tree
                self.save_unmet(job)

        if self.hwm:
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                # engine just dropped below the high-water mark:
                # parked jobs may now fit somewhere
                if self.loads[idx] == self.hwm-1:
                    self.update_graph(None)
646 654
647 655
648 656
    def update_graph(self, dep_id=None, success=True):
        """dep_id just finished. Update our dependency
        graph and submit any jobs that just became runable.

        Called with dep_id=None to update entire graph for hwm, but without finishing
        a task.
        """
        # update any jobs that depended on the dependency
        jobs = self.graph.pop(dep_id, [])

        # recheck *all* jobs if
        # a) we have HWM and an engine just become no longer full
        # or b) dep_id was given as None
        # (note precedence: `A or (B and C)` — hwm check only widens the scan)

        if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
            jobs = self.depending.keys()

        # oldest submissions first
        for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
            job = self.depending[msg_id]

            if job.after.unreachable(self.all_completed, self.all_failed)\
                    or job.follow.unreachable(self.all_completed, self.all_failed):
                self.fail_unreachable(msg_id)

            elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
                if self.maybe_run(job):

                    self.depending.pop(msg_id)
                    # scrub this job from its dependents' reverse-dep sets
                    for mid in job.dependents:
                        if mid in self.graph:
                            self.graph[mid].remove(msg_id)
687 695
688 696 #----------------------------------------------------------------------
689 697 # methods to be overridden by subclasses
690 698 #----------------------------------------------------------------------
691 699
692 700 def add_job(self, idx):
693 701 """Called after self.targets[idx] just got the job with header.
694 702 Override with subclasses. The default ordering is simple LRU.
695 703 The default loads are the number of outstanding jobs."""
696 704 self.loads[idx] += 1
697 705 for lis in (self.targets, self.loads):
698 706 lis.append(lis.pop(idx))
699 707
700 708
    def finish_job(self, idx):
        """Called after self.targets[idx] just finished a job.
        Override with subclasses."""
        # default load metric is outstanding-job count: just decrement
        self.loads[idx] -= 1
705 713
706 714
707 715
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):
    """Build the ZMQ streams for a TaskScheduler and run it.

    in_addr/out_addr : client-facing and engine-facing ROUTER bind urls
    mon_addr         : PUB connect url for the Hub monitor
    not_addr         : SUB connect url for Hub notifications
    in_thread        : share the parent's Context/IOLoop instead of
                       creating private ones (for same-process use)

    When not in_thread, this call blocks in loop.start().
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()
    # client-facing stream
    ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    ins.setsockopt(zmq.IDENTITY, identity)
    ins.bind(in_addr)

    # engine-facing stream
    outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    outs.setsockopt(zmq.IDENTITY, identity)
    outs.bind(out_addr)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
    mons.connect(mon_addr)
    # subscribe to everything from the Hub notifier
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    else:
        if log_url:
            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
        else:
            log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            scheduler.log.critical("Interrupted, exiting...")
759 767
@@ -1,480 +1,495 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 from IPython.external.decorator import decorator
43
42 44 # IPython imports
43 45 from IPython.config.application import Application
44 46 from IPython.utils import py3compat
45 47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
46 48 from IPython.utils.newserialized import serialize, unserialize
47 49 from IPython.zmq.log import EnginePUBHandler
48 50
49 51 if py3compat.PY3:
50 52 buffer = memoryview
51 53
52 54 #-----------------------------------------------------------------------------
53 55 # Classes
54 56 #-----------------------------------------------------------------------------
55 57
class Namespace(dict):
    """Subclass of dict for attribute access to keys."""

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # plain membership test instead of py2-only iterkeys(); on Python 3
        # `self.iterkeys()` would itself re-enter __getattr__ and recurse
        if key in self:
            return self[key]
        else:
            raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        # refuse to shadow any real dict method/attribute
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
71 73
72 74
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Lookups fall back to the reverse (value -> key) mapping, so both
    d[key] and d[value] resolve.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        # items() (not py2-only iteritems) so this also works on Python 3
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            # fall back to reverse lookup: key may be a stored value
            return self._reverse[key]

    def __setitem__(self, key, value):
        # a key that is already a stored *value* would make lookups ambiguous
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        # keep the reverse map in sync with the removal
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
104 106
105 107 #-----------------------------------------------------------------------------
106 108 # Functions
107 109 #-----------------------------------------------------------------------------
108 110
@decorator
def log_errors(f, self, *args, **kwargs):
    """decorator to log unhandled exceptions raised in a method.

    For use wrapping on_recv callbacks, so that exceptions
    do not cause the stream to be closed.
    """
    try:
        return f(self, *args, **kwargs)
    except Exception:
        # swallow after logging: letting the exception propagate out of an
        # on_recv callback would take the stream down with it
        self.log.error("Uncaught exception in %r" % f, exc_info=True)
122
123
def asbytes(s):
    """ensure that an object is ascii bytes

    Text (``unicode`` on Python 2, ``str`` on Python 3) is ascii-encoded;
    anything else is returned unchanged.
    """
    # type(u'') is unicode on py2 and str on py3, so this works on both
    # (the original py2-only `unicode` name does not exist on py3)
    if isinstance(s, type(u'')):
        s = s.encode('ascii')
    return s
114 129
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    # must contain a scheme separator at all
    if '://' not in url:
        return False
    # and the scheme must be one zeromq understands
    scheme = url.split('://', 1)[0]
    return scheme.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
123 138
def validate_url(url):
    """validate a url for zeromq

    Raises TypeError for a non-string url and AssertionError for a
    malformed one; returns True when the url looks valid.
    """
    # accept both text and (py2) byte strings; type(u'') is unicode on py2,
    # str on py3 (replaces py2-only `basestring`)
    if not isinstance(url, (str, type(u''))):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            # BUG FIX: was `%(port, url)` — `port` is unbound when int() fails
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
155 170
156 171
def validate_url_container(container):
    """validate a potentially nested collection of urls.

    A bare string is validated directly; dicts are validated by value;
    any other iterable is validated element-wise (recursively).
    """
    # type(u'') covers py2 unicode / py3 str (replaces py2-only basestring)
    if isinstance(container, (str, type(u''))):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        # values() (not py2-only itervalues) so this also works on Python 3
        container = container.values()

    for element in container:
        validate_url_container(element)
167 182
168 183
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    pieces = url.split('://')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    proto, addr = pieces
    host_port = addr.split(':')
    assert len(host_port) == 2, 'Invalid url: %r'%url
    host, port = host_port
    # port is returned as the original string, not an int
    return proto, host, port
178 193
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    # concrete addresses pass straight through
    if ip not in ('0.0.0.0', '*'):
        return ip
    try:
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
    except (socket.gaierror, IndexError):
        # couldn't identify this machine, assume localhost
        external_ips = []
    if location is None or location in external_ips or not external_ips:
        # If location is unspecified or cannot be determined, assume local
        return '127.0.0.1'
    elif location:
        # location is a remote host: connect to it directly
        return location
    return ip
194 209
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not a tcp url; could be ipc, inproc, etc. — leave as-is
        return url
    # rebuild the url around the disambiguated address
    return "%s://%s:%s" % (proto, disambiguate_ip_address(ip, location), port)
209 224
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        # serialize element-wise, extracting large payloads as raw buffers
        clist = canSequence(obj)
        slist = map(serialize, clist)
        for s in slist:
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                # data travels separately; the pickled wrapper carries None
                s.data = None
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted key order keeps the buffer order deterministic for unpacking
        for k in sorted(obj.iterkeys()):
            s = serialize(can(obj[k]))
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
            sobj[k] = s
        return pickle.dumps(sobj,-1),databuffers
    else:
        s = serialize(can(obj))
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return pickle.dumps(s,-1),databuffers
252 267
253 268
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers.

    Returns (obj, remaining_bufs): buffers consumed from the front of `bufs`
    in the same deterministic order serialize_object emitted them.
    """
    bufs = list(bufs)
    # first buffer is always the pickled metadata wrapper
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        for s in sobj:
            if s.data is None:
                # data was shipped out-of-band: reattach the next buffer
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj)), bufs
    elif isinstance(sobj, dict):
        newobj = {}
        # sorted key order mirrors serialize_object's buffer ordering
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj, bufs
    else:
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj)), bufs
275 290
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    # frame 0: the canned function; frames 1-2: args/kwargs metadata;
    # remaining frames: the extracted large data buffers, in order
    msg = [pickle.dumps(can(f),-1)]
    databuffers = [] # for large objects
    sargs, bufs = serialize_object(args,threshold)
    msg.append(sargs)
    databuffers.extend(bufs)
    skwargs, bufs = serialize_object(kwargs,threshold)
    msg.append(skwargs)
    databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg
290 305
def _set_buffer_data(s, bufs, copy):
    """Reattach one out-of-band buffer to a serialized object.

    Pops the next raw frame off `bufs` and stores it on `s`, wrapping
    buffer/ndarray payloads in a `buffer` for zero-copy access.
    """
    m = bufs.pop(0)
    if s.getTypeDescriptor() in ('buffer', 'ndarray'):
        # always use a buffer, until memoryviews get sorted out
        s.data = buffer(m)
        # disable memoryview support
        # if copy:
        #     s.data = buffer(m)
        # else:
        #     s.data = m.buffer
    else:
        if copy:
            s.data = m
        else:
            # zmq.Message frame: extract its bytes
            s.data = m.bytes


def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # first three frames are always pickles; extract their bytes up front
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    # reattach out-of-band buffers in the same order they were packed
    for sa in sargs:
        if sa.data is None:
            _set_buffer_data(sa, bufs, copy)

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted key order mirrors serialize_object's buffer ordering
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            _set_buffer_data(sa, bufs, copy)

        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
344 359
345 360 #--------------------------------------------------------------------------
346 361 # helpers for implementing old MEC API via view.apply
347 362 #--------------------------------------------------------------------------
348 363
def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """
    # functions tagged as __main__ are rebound to the target namespace when
    # uncanned on the engine — presumably via IPython.utils.pickleutil (not
    # shown here); confirm against `uncan`
    f.__module__ = '__main__'
    return f
356 371
@interactive
def _push(**ns):
    """helper method for implementing `client.push` via `client.apply`"""
    # executed remotely via apply, where globals() is the user namespace
    globals().update(ns)
361 376
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`

    `keys` may be a single name or a list/tuple/set of names; raises
    NameError for any name missing from the namespace.
    """
    user_ns = globals()
    if isinstance(keys, (list, tuple, set)):
        for key in keys:
            # `in` instead of py2-only dict.has_key()
            if key not in user_ns:
                raise NameError("name '%s' is not defined"%key)
        return map(user_ns.get, keys)
    else:
        if keys not in user_ns:
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)
375 390
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`

    Runs `code` in globals() -- which, because of the @interactive
    decorator, is the engine's user namespace.
    """
    # tuple form: valid as the Python 2 exec statement *and* the
    # Python 3 exec() builtin, unlike `exec code in globals()`
    exec(code, globals())
380 395
#--------------------------------------------------------------------------
# extra process management utilities
#--------------------------------------------------------------------------

# ports already handed out; never return the same port twice, even if the
# OS would offer it again after the probe socket is closed.
_random_ports = set()

def select_random_ports(n):
    """Select and return n random ports that are available.

    Each port is chosen by binding a socket to port 0 (letting the OS
    pick a free one).  All n probe sockets are held open until selection
    is complete, so one call cannot pick the same port twice; they are
    then closed, so there is an inherent race with other processes
    between close and the caller's own bind.
    """
    ports = []
    for i in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        # re-roll if this port was handed out by a previous call
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        ports.append(sock)
    for i, sock in enumerate(ports):
        port = sock.getsockname()[1]
        sock.close()
        ports[i] = port
        _random_ports.add(port)
    return ports
404 419
def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup.

    Installs a handler for SIGINT/SIGABRT/SIGTERM that terminates every
    child process and then exits (status 0 only for SIGINT).
    """
    def _relay(sig, frame):
        log = Application.instance().log
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()
        # False (0) exit status for a plain Ctrl-C, True (1) otherwise
        sys.exit(sig != SIGINT)

    for signum in (SIGINT, SIGABRT, SIGTERM):
        signal(signum, _relay)
417 432
def generate_exec_key(keyfile):
    """Generate a fresh random execution key and write it to `keyfile`.

    The file is created with user-only RW permissions (0600) from the
    start -- rather than created world-readable and chmod'ed afterwards --
    so the key is never momentarily exposed to other users.  The mode
    has no effect on Windows.
    """
    import uuid
    newkey = str(uuid.uuid4())
    fd = os.open(keyfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                 stat.S_IRUSR | stat.S_IWUSR)
    with os.fdopen(fd, 'w') as f:
        # f.write('ipython-key ')
        f.write(newkey+'\n')
    # chmod as well, in case keyfile already existed with wider permissions
    os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
427 442
428 443
def integer_loglevel(loglevel):
    """Coerce a loglevel to its integer value.

    Accepts an int, a digit string (e.g. '10'), or a level name
    (e.g. 'DEBUG', case-insensitive) resolved via the logging module.
    Anything else is returned unchanged.
    """
    try:
        loglevel = int(loglevel)
    except (TypeError, ValueError):
        # TypeError as well, so non-string non-numbers fall through
        # instead of raising here
        if isinstance(loglevel, str):
            # case-insensitive: 'debug' and 'DEBUG' both resolve
            loglevel = getattr(logging, loglevel.upper())
    return loglevel
436 451
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler to the `logname` logger, publishing on `iface`.

    Does nothing if the logger already has a PUBHandler, so repeated
    calls cannot duplicate log output.
    """
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # already publishing; don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    pub_sock = context.socket(zmq.PUB)
    pub_sock.connect(iface)
    handler = handlers.PUBHandler(pub_sock)
    handler.setLevel(level)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(level)
450 465
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for `engine` to the root logger.

    Publishes root-logger records on `iface`.  Returns the root logger,
    or None if a PUBHandler was already attached (nothing is added then).
    """
    logger = logging.getLogger()
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # already publishing; don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    pub_sock = context.socket(zmq.PUB)
    pub_sock.connect(iface)
    handler = EnginePUBHandler(engine, pub_sock)
    handler.setLevel(level)
    logger.addHandler(handler)
    logger.setLevel(level)
    return logger
464 479
def local_logger(logname, loglevel=logging.DEBUG):
    """Configure the `logname` logger with a stderr StreamHandler and return it.

    A timestamped formatter is attached.  If a StreamHandler is already
    present, no second one is added, but the logger is still returned
    (the previous version inconsistently returned None on that path).
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        # don't add a second StreamHandler, but keep the return value
        # consistent with the configuring path below
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
480 495
General Comments 0
You need to be logged in to leave comments. Login now