allow true single-threaded Controller...
MinRK
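In short: when the schedulers run as threads (use_threads=True, i.e. zmq.devices.ThreadMonitoredQueue), the Python task scheduler previously still went into a multiprocessing.Process; with this change it is launched with in_thread=True on the Controller's own IOLoop, so the whole Controller can live in a single process. A rough sketch of the resulting usage follows; the --usethreads flag comes from the flags dict in the diff below, while the invocation itself is illustrative:

    # Illustrative only: run the whole Controller, schedulers included, in one
    # process. --usethreads sets IPControllerApp.use_threads=True, switching
    # mq_class to zmq.devices.ThreadMonitoredQueue; the Python task scheduler
    # is then run via launch_scheduler(..., in_thread=True) on the parent loop.
    import sys
    from IPython.parallel.apps.ipcontrollerapp import launch_new_instance

    sys.argv = ['ipcontroller', '--usethreads']
    launch_new_instance()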
IPython/parallel/apps/ipcontrollerapp.py
@@ -1,408 +1,414 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import os
27 27 import socket
28 28 import stat
29 29 import sys
30 30 import uuid
31 31
32 32 from multiprocessing import Process
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37 from zmq.utils import jsonapi as json
38 38
39 39 from IPython.config.application import boolean_flag
40 40 from IPython.core.profiledir import ProfileDir
41 41
42 42 from IPython.parallel.apps.baseapp import (
43 43 BaseParallelApplication,
44 44 base_flags
45 45 )
46 46 from IPython.utils.importstring import import_item
47 47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
48 48
49 49 # from IPython.parallel.controller.controller import ControllerFactory
50 50 from IPython.zmq.session import Session
51 51 from IPython.parallel.controller.heartmonitor import HeartMonitor
52 52 from IPython.parallel.controller.hub import HubFactory
53 53 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
54 54 from IPython.parallel.controller.sqlitedb import SQLiteDB
55 55
56 56 from IPython.parallel.util import signal_children, split_url
57 57
58 58 # conditional import of MongoDB backend class
59 59
60 60 try:
61 61 from IPython.parallel.controller.mongodb import MongoDB
62 62 except ImportError:
63 63 maybe_mongo = []
64 64 else:
65 65 maybe_mongo = [MongoDB]
66 66
67 67
68 68 #-----------------------------------------------------------------------------
69 69 # Module level variables
70 70 #-----------------------------------------------------------------------------
71 71
72 72
73 73 #: The default config file name for this application
74 74 default_config_file_name = u'ipcontroller_config.py'
75 75
76 76
77 77 _description = """Start the IPython controller for parallel computing.
78 78
79 79 The IPython controller provides a gateway between the IPython engines and
80 80 clients. The controller needs to be started before the engines and can be
81 81 configured using command line options or using a cluster directory. Cluster
82 82 directories contain config, log and security files and are usually located in
83 83 your ipython directory and are named "profile_<name>". See the `profile`
84 84 and `profile_dir` options for details.
85 85 """
86 86
87 87
88 88
89 89
90 90 #-----------------------------------------------------------------------------
91 91 # The main application
92 92 #-----------------------------------------------------------------------------
93 93 flags = {}
94 94 flags.update(base_flags)
95 95 flags.update({
96 96 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
97 97 'Use threads instead of processes for the schedulers'),
98 98 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
99 99 'use the SQLiteDB backend'),
100 100 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
101 101 'use the MongoDB backend'),
102 102 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
103 103 'use the in-memory DictDB backend'),
104 104 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
105 105 'reuse existing json connection files')
106 106 })
107 107
108 108 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
109 109 "Use HMAC digests for authentication of messages.",
110 110 "Don't authenticate messages."
111 111 ))
112 112
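# Illustrative mapping (not part of the diff): the flags above turn CLI
# switches into config entries, e.g.
#   ipcontroller --usethreads --reuse --sqlitedb --no-secure
# sets IPControllerApp.use_threads=True, IPControllerApp.reuse_files=True,
# HubFactory.db_class='IPython.parallel.controller.sqlitedb.SQLiteDB',
# and IPControllerApp.secure=False (via the boolean_flag pair).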
113 113 class IPControllerApp(BaseParallelApplication):
114 114
115 115 name = u'ipcontroller'
116 116 description = _description
117 117 config_file_name = Unicode(default_config_file_name)
118 118 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
119 119
120 120 # change default to True
121 121 auto_create = Bool(True, config=True,
122 122 help="""Whether to create profile dir if it doesn't exist.""")
123 123
124 124 reuse_files = Bool(False, config=True,
125 125 help='Whether to reuse existing json connection files.'
126 126 )
127 127 secure = Bool(True, config=True,
128 128 help='Whether to use HMAC digests for extra message authentication.'
129 129 )
130 130 ssh_server = Unicode(u'', config=True,
131 131 help="""ssh url for clients to use when connecting to the Controller
132 132 processes. It should be of the form: [user@]server[:port]. The
133 133 Controller's listening addresses must be accessible from the ssh server""",
134 134 )
135 135 location = Unicode(u'', config=True,
136 136 help="""The external IP or domain name of the Controller, used for disambiguating
137 137 engine and client connections.""",
138 138 )
139 139 import_statements = List([], config=True,
140 140 help="import statements to be run at startup. Necessary in some environments"
141 141 )
142 142
143 143 use_threads = Bool(False, config=True,
144 144 help='Use threads instead of processes for the schedulers',
145 145 )
146 146
147 147 # internal
148 148 children = List()
149 149 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
150 150
151 151 def _use_threads_changed(self, name, old, new):
152 152 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
153 153
154 154 aliases = Dict(dict(
155 155 log_level = 'IPControllerApp.log_level',
156 156 log_url = 'IPControllerApp.log_url',
157 157 reuse_files = 'IPControllerApp.reuse_files',
158 158 secure = 'IPControllerApp.secure',
159 159 ssh = 'IPControllerApp.ssh_server',
160 160 use_threads = 'IPControllerApp.use_threads',
161 161 import_statements = 'IPControllerApp.import_statements',
162 162 location = 'IPControllerApp.location',
163 163
164 164 ident = 'Session.session',
165 165 user = 'Session.username',
166 166 exec_key = 'Session.keyfile',
167 167
168 168 url = 'HubFactory.url',
169 169 ip = 'HubFactory.ip',
170 170 transport = 'HubFactory.transport',
171 171 port = 'HubFactory.regport',
172 172
173 173 ping = 'HeartMonitor.period',
174 174
175 175 scheme = 'TaskScheduler.scheme_name',
176 176 hwm = 'TaskScheduler.hwm',
177 177
178 178
179 179 profile = "BaseIPythonApplication.profile",
180 180 profile_dir = 'ProfileDir.location',
181 181
182 182 ))
183 183 flags = Dict(flags)
184 184
185 185
186 186 def save_connection_dict(self, fname, cdict):
187 187 """save a connection dict to json file."""
188 188 c = self.config
189 189 url = cdict['url']
190 190 location = cdict['location']
191 191 if not location:
192 192 try:
193 193 proto,ip,port = split_url(url)
194 194 except AssertionError:
195 195 pass
196 196 else:
197 197 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
198 198 cdict['location'] = location
199 199 fname = os.path.join(self.profile_dir.security_dir, fname)
200 200 with open(fname, 'w') as f:
201 201 f.write(json.dumps(cdict, indent=2))
202 202 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
203 203
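# Illustrative (values invented): the files written above and parsed below
# are small JSON dicts, e.g. ipcontroller-client.json:
#   {
#     "exec_key": "a1b2c3d4-...",     # HMAC key; "" when secure=False
#     "ssh": "",                      # optional [user@]server[:port]
#     "url": "tcp://127.0.0.1:55555", # transport://ip:regport
#     "location": "10.0.1.5"          # external IP/hostname of the Controller
#   }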
204 204 def load_config_from_json(self):
205 205 """load config from existing json connector files."""
206 206 c = self.config
207 207 # load from engine config
208 208 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
209 209 cfg = json.loads(f.read())
210 210 key = c.Session.key = cfg['exec_key']
211 211 xport,addr = cfg['url'].split('://')
212 212 c.HubFactory.engine_transport = xport
213 213 ip,ports = addr.split(':')
214 214 c.HubFactory.engine_ip = ip
215 215 c.HubFactory.regport = int(ports)
216 216 self.location = cfg['location']
217 217
218 218 # load client config
219 219 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
220 220 cfg = json.loads(f.read())
221 221 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
222 222 xport,addr = cfg['url'].split('://')
223 223 c.HubFactory.client_transport = xport
224 224 ip,ports = addr.split(':')
225 225 c.HubFactory.client_ip = ip
226 226 self.ssh_server = cfg['ssh']
227 227 assert int(ports) == c.HubFactory.regport, "regport mismatch"
228 228
229 229 def init_hub(self):
230 230 c = self.config
231 231
232 232 self.do_import_statements()
233 233 reusing = self.reuse_files
234 234 if reusing:
235 235 try:
236 236 self.load_config_from_json()
237 237 except (AssertionError,IOError):
238 238 reusing=False
239 239 # check again, because reusing may have failed:
240 240 if reusing:
241 241 pass
242 242 elif self.secure:
243 243 key = str(uuid.uuid4())
244 244 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
245 245 # with open(keyfile, 'w') as f:
246 246 # f.write(key)
247 247 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
248 248 c.Session.key = key
249 249 else:
250 250 key = c.Session.key = ''
251 251
252 252 try:
253 253 self.factory = HubFactory(config=c, log=self.log)
254 254 # self.start_logging()
255 255 self.factory.init_hub()
256 256 except:
257 257 self.log.error("Couldn't construct the Controller", exc_info=True)
258 258 self.exit(1)
259 259
260 260 if not reusing:
261 261 # save to new json config files
262 262 f = self.factory
263 263 cdict = {'exec_key' : key,
264 264 'ssh' : self.ssh_server,
265 265 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
266 266 'location' : self.location
267 267 }
268 268 self.save_connection_dict('ipcontroller-client.json', cdict)
269 269 edict = cdict
270 270 edict['url'] = "%s://%s:%s" % (f.engine_transport, f.engine_ip, f.regport)
271 271 self.save_connection_dict('ipcontroller-engine.json', edict)
272 272
273 273 #
274 274 def init_schedulers(self):
275 275 children = self.children
276 276 mq = import_item(str(self.mq_class))
277 277
278 278 hub = self.factory
279 279 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
280 280 # IOPub relay (in a Process)
281 281 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
282 282 q.bind_in(hub.client_info['iopub'])
283 283 q.bind_out(hub.engine_info['iopub'])
284 284 q.setsockopt_out(zmq.SUBSCRIBE, '')
285 285 q.connect_mon(hub.monitor_url)
286 286 q.daemon=True
287 287 children.append(q)
288 288
289 289 # Multiplexer Queue (in a Process)
290 290 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
291 291 q.bind_in(hub.client_info['mux'])
292 292 q.setsockopt_in(zmq.IDENTITY, 'mux')
293 293 q.bind_out(hub.engine_info['mux'])
294 294 q.connect_mon(hub.monitor_url)
295 295 q.daemon=True
296 296 children.append(q)
297 297
298 298 # Control Queue (in a Process)
299 299 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
300 300 q.bind_in(hub.client_info['control'])
301 301 q.setsockopt_in(zmq.IDENTITY, 'control')
302 302 q.bind_out(hub.engine_info['control'])
303 303 q.connect_mon(hub.monitor_url)
304 304 q.daemon=True
305 305 children.append(q)
306 306 try:
307 307 scheme = self.config.TaskScheduler.scheme_name
308 308 except AttributeError:
309 309 scheme = TaskScheduler.scheme_name.get_default_value()
310 310 # Task Queue (in a Process)
311 311 if scheme == 'pure':
312 312 self.log.warn("task::using pure XREQ Task scheduler")
313 313 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
314 314 # q.setsockopt_out(zmq.HWM, hub.hwm)
315 315 q.bind_in(hub.client_info['task'][1])
316 316 q.setsockopt_in(zmq.IDENTITY, 'task')
317 317 q.bind_out(hub.engine_info['task'])
318 318 q.connect_mon(hub.monitor_url)
319 319 q.daemon=True
320 320 children.append(q)
321 321 elif scheme == 'none':
322 322 self.log.warn("task::using no Task scheduler")
323 323
324 324 else:
325 325 self.log.info("task::using Python %s Task scheduler"%scheme)
326 326 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
327 327 hub.monitor_url, hub.client_info['notification'])
328 328 kwargs = dict(logname='scheduler', loglevel=self.log_level,
329 329 log_url = self.log_url, config=dict(self.config))
330 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
331 q.daemon=True
332 children.append(q)
330 if 'Process' in self.mq_class:
331 # run the Python scheduler in a Process
332 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
333 q.daemon=True
334 children.append(q)
335 else:
336 # single-threaded Controller
337 kwargs['in_thread'] = True
338 launch_scheduler(*sargs, **kwargs)
333 339
334 340
335 341 def save_urls(self):
336 342 """save the registration urls to files."""
337 343 c = self.config
338 344
339 345 sec_dir = self.profile_dir.security_dir
340 346 cf = self.factory
341 347
342 348 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
343 349 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
344 350
345 351 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
346 352 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
347 353
348 354
349 355 def do_import_statements(self):
350 356 statements = self.import_statements
351 357 for s in statements:
352 358 try:
353 359 self.log.msg("Executing statement: '%s'" % s)
354 360 exec s in globals(), locals()
355 361 except:
356 362 self.log.msg("Error running statement: %s" % s)
357 363
358 364 def forward_logging(self):
359 365 if self.log_url:
360 366 self.log.info("Forwarding logging to %s"%self.log_url)
361 367 context = zmq.Context.instance()
362 368 lsock = context.socket(zmq.PUB)
363 369 lsock.connect(self.log_url)
364 370 handler = PUBHandler(lsock)
365 371 self.log.removeHandler(self._log_handler)
366 372 handler.root_topic = 'controller'
367 373 handler.setLevel(self.log_level)
368 374 self.log.addHandler(handler)
369 375 self._log_handler = handler
370 376 # #
371 377
372 378 def initialize(self, argv=None):
373 379 super(IPControllerApp, self).initialize(argv)
374 380 self.forward_logging()
375 381 self.init_hub()
376 382 self.init_schedulers()
377 383
378 384 def start(self):
379 385 # Start the subprocesses:
380 386 self.factory.start()
381 387 child_procs = []
382 388 for child in self.children:
383 389 child.start()
384 390 if isinstance(child, ProcessMonitoredQueue):
385 391 child_procs.append(child.launcher)
386 392 elif isinstance(child, Process):
387 393 child_procs.append(child)
388 394 if child_procs:
389 395 signal_children(child_procs)
390 396
391 397 self.write_pid_file(overwrite=True)
392 398
393 399 try:
394 400 self.factory.loop.start()
395 401 except KeyboardInterrupt:
396 402 self.log.critical("Interrupted, Exiting...\n")
397 403
398 404
399 405
400 406 def launch_new_instance():
401 407 """Create and run the IPython controller"""
402 408 app = IPControllerApp.instance()
403 409 app.initialize()
404 410 app.start()
405 411
406 412
407 413 if __name__ == '__main__':
408 414 launch_new_instance()
IPython/parallel/controller/scheduler.py
@@ -1,692 +1,703 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 42 from IPython.config.loader import Config
42 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
43 44
44 45 from IPython.parallel import error
45 46 from IPython.parallel.factory import SessionFactory
46 47 from IPython.parallel.util import connect_logger, local_logger
47 48
48 49 from .dependency import Dependency
49 50
50 51 @decorator
51 52 def logged(f,self,*args,**kwargs):
52 53 # print ("#--------------------")
53 54 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
54 55 # print ("#--")
55 56 return f(self,*args, **kwargs)
56 57
57 58 #----------------------------------------------------------------------
58 59 # Chooser functions
59 60 #----------------------------------------------------------------------
60 61
61 62 def plainrandom(loads):
62 63 """Plain random pick."""
63 64 n = len(loads)
64 65 return randint(0,n-1)
65 66
66 67 def lru(loads):
67 68 """Always pick the front of the line.
68 69
69 70 The content of `loads` is ignored.
70 71
71 72 Assumes LRU ordering of loads, with oldest first.
72 73 """
73 74 return 0
74 75
75 76 def twobin(loads):
76 77 """Pick two at random, use the LRU of the two.
77 78
78 79 The content of loads is ignored.
79 80
80 81 Assumes LRU ordering of loads, with oldest first.
81 82 """
82 83 n = len(loads)
83 84 a = randint(0,n-1)
84 85 b = randint(0,n-1)
85 86 return min(a,b)
86 87
87 88 def weighted(loads):
88 89 """Pick two at random using inverse load as weight.
89 90
90 91 Return the less loaded of the two.
91 92 """
92 93 # weight 0 a million times more than 1:
93 94 weights = 1./(1e-6+numpy.array(loads))
94 95 sums = weights.cumsum()
95 96 t = sums[-1]
96 97 x = random()*t
97 98 y = random()*t
98 99 idx = 0
99 100 idy = 0
100 101 while sums[idx] < x:
101 102 idx += 1
102 103 while sums[idy] < y:
103 104 idy += 1
104 105 if weights[idy] > weights[idx]:
105 106 return idy
106 107 else:
107 108 return idx
108 109
109 110 def leastload(loads):
110 111 """Always choose the lowest load.
111 112
112 113 If the lowest load occurs more than once, the first
113 114 occurrence will be used. If loads has LRU ordering, this means
114 115 the LRU of those with the lowest load is chosen.
115 116 """
116 117 return loads.index(min(loads))
117 118
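# Quick illustration of the chooser contract (loads = outstanding-task counts,
# LRU-ordered with the oldest engine first; values invented):
#   lru([3, 0, 1])       -> 0  (head of the line; loads ignored)
#   leastload([3, 0, 1]) -> 1  (index of the minimum load)
#   weighted([0, 5]) almost always -> 0 (weights are 1/(1e-6 + load))
#   plainrandom and twobin pick randomly, twobin keeping the less-loaded of
#   two sampled indices. The active chooser is selected by
#   TaskScheduler.scheme_name (the `scheme` alias in ipcontrollerapp above).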
118 119 #---------------------------------------------------------------------
119 120 # Classes
120 121 #---------------------------------------------------------------------
121 122 # store empty default dependency:
122 123 MET = Dependency([])
123 124
124 125 class TaskScheduler(SessionFactory):
125 126 """Python TaskScheduler object.
126 127
127 128 This is the simplest object that supports msg_id based
128 129 DAG dependencies. *Only* task msg_ids are checked, not
129 130 msg_ids of jobs submitted via the MUX queue.
130 131
131 132 """
132 133
133 134 hwm = Int(0, config=True, shortname='hwm',
134 135 help="""specify the High Water Mark (HWM) for the downstream
135 136 socket in the Task scheduler. This is the maximum number
136 137 of allowed outstanding tasks on each engine."""
137 138 )
138 139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
139 140 'leastload', config=True, shortname='scheme', allow_none=False,
140 141 help="""select the task scheduler scheme [default: Python LRU]
141 142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
142 143 )
143 144 def _scheme_name_changed(self, old, new):
144 145 self.log.debug("Using scheme %r"%new)
145 146 self.scheme = globals()[new]
146 147
147 148 # input arguments:
148 149 scheme = Instance(FunctionType) # function for determining the destination
149 150 def _scheme_default(self):
150 151 return leastload
151 152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
152 153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
153 154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
154 155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
155 156
156 157 # internals:
157 158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
158 159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
159 160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
160 161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
161 162 pending = Dict() # dict by engine_uuid of submitted tasks
162 163 completed = Dict() # dict by engine_uuid of completed tasks
163 164 failed = Dict() # dict by engine_uuid of failed tasks
164 165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
165 166 clients = Dict() # dict by msg_id for who submitted the task
166 167 targets = List() # list of target IDENTs
167 168 loads = List() # list of engine loads
168 169 # full = Set() # set of IDENTs that have HWM outstanding tasks
169 170 all_completed = Set() # set of all completed tasks
170 171 all_failed = Set() # set of all failed tasks
171 172 all_done = Set() # set of all finished tasks=union(completed,failed)
172 173 all_ids = Set() # set of all submitted task IDs
173 174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
174 175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
175 176
176 177
177 178 def start(self):
178 179 self.engine_stream.on_recv(self.dispatch_result, copy=False)
179 180 self._notification_handlers = dict(
180 181 registration_notification = self._register_engine,
181 182 unregistration_notification = self._unregister_engine
182 183 )
183 184 self.notifier_stream.on_recv(self.dispatch_notification)
184 185 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
185 186 self.auditor.start()
186 187 self.log.info("Scheduler started [%s]"%self.scheme_name)
187 188
188 189 def resume_receiving(self):
189 190 """Resume accepting jobs."""
190 191 self.client_stream.on_recv(self.dispatch_submission, copy=False)
191 192
192 193 def stop_receiving(self):
193 194 """Stop accepting jobs while there are no engines.
194 195 Leave them in the ZMQ queue."""
195 196 self.client_stream.on_recv(None)
196 197
197 198 #-----------------------------------------------------------------------
198 199 # [Un]Registration Handling
199 200 #-----------------------------------------------------------------------
200 201
201 202 def dispatch_notification(self, msg):
202 203 """dispatch register/unregister events."""
203 204 try:
204 205 idents,msg = self.session.feed_identities(msg)
205 206 except ValueError:
206 207 self.log.warn("task::Invalid Message: %r"%msg)
207 208 return
208 209 try:
209 210 msg = self.session.unpack_message(msg)
210 211 except ValueError:
211 212 self.log.warn("task::Unauthorized message from: %r"%idents)
212 213 return
213 214
214 215 msg_type = msg['msg_type']
215 216
216 217 handler = self._notification_handlers.get(msg_type, None)
217 218 if handler is None:
218 219 self.log.error("Unhandled message type: %r"%msg_type)
219 220 else:
220 221 try:
221 222 handler(str(msg['content']['queue']))
222 223 except KeyError:
223 224 self.log.error("task::Invalid notification msg: %r"%msg)
224 225
225 226 @logged
226 227 def _register_engine(self, uid):
227 228 """New engine with ident `uid` became available."""
228 229 # head of the line:
229 230 self.targets.insert(0,uid)
230 231 self.loads.insert(0,0)
231 232 # initialize sets
232 233 self.completed[uid] = set()
233 234 self.failed[uid] = set()
234 235 self.pending[uid] = {}
235 236 if len(self.targets) == 1:
236 237 self.resume_receiving()
237 238 # rescan the graph:
238 239 self.update_graph(None)
239 240
240 241 def _unregister_engine(self, uid):
241 242 """Existing engine with ident `uid` became unavailable."""
242 243 if len(self.targets) == 1:
243 244 # this was our only engine
244 245 self.stop_receiving()
245 246
246 247 # handle any potentially finished tasks:
247 248 self.engine_stream.flush()
248 249
249 250 # don't pop destinations, because they might be used later
250 251 # map(self.destinations.pop, self.completed.pop(uid))
251 252 # map(self.destinations.pop, self.failed.pop(uid))
252 253
253 254 # prevent this engine from receiving work
254 255 idx = self.targets.index(uid)
255 256 self.targets.pop(idx)
256 257 self.loads.pop(idx)
257 258
258 259 # wait 5 seconds before cleaning up pending jobs, since the results might
259 260 # still be incoming
260 261 if self.pending[uid]:
261 262 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
262 263 dc.start()
263 264 else:
264 265 self.completed.pop(uid)
265 266 self.failed.pop(uid)
266 267
267 268
268 269 @logged
269 270 def handle_stranded_tasks(self, engine):
270 271 """Deal with jobs resident in an engine that died."""
271 272 lost = self.pending[engine]
272 273 for msg_id in lost.keys():
273 274 if msg_id not in self.pending[engine]:
274 275 # prevent double-handling of messages
275 276 continue
276 277
277 278 raw_msg = lost[msg_id][0]
278 279 idents,msg = self.session.feed_identities(raw_msg, copy=False)
279 280 parent = self.session.unpack(msg[1].bytes)
280 281 idents = [engine, idents[0]]
281 282
282 283 # build fake error reply
283 284 try:
284 285 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
285 286 except:
286 287 content = error.wrap_exception()
287 288 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
288 289 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
289 290 # and dispatch it
290 291 self.dispatch_result(raw_reply)
291 292
292 293 # finally scrub completed/failed lists
293 294 self.completed.pop(engine)
294 295 self.failed.pop(engine)
295 296
296 297
297 298 #-----------------------------------------------------------------------
298 299 # Job Submission
299 300 #-----------------------------------------------------------------------
300 301 @logged
301 302 def dispatch_submission(self, raw_msg):
302 303 """Dispatch job submission to appropriate handlers."""
303 304 # ensure targets up to date:
304 305 self.notifier_stream.flush()
305 306 try:
306 307 idents, msg = self.session.feed_identities(raw_msg, copy=False)
307 308 msg = self.session.unpack_message(msg, content=False, copy=False)
308 309 except Exception:
309 310 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
310 311 return
311 312
312 313
313 314 # send to monitor
314 315 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
315 316
316 317 header = msg['header']
317 318 msg_id = header['msg_id']
318 319 self.all_ids.add(msg_id)
319 320
320 321 # targets
321 322 targets = set(header.get('targets', []))
322 323 retries = header.get('retries', 0)
323 324 self.retries[msg_id] = retries
324 325
325 326 # time dependencies
326 327 after = Dependency(header.get('after', []))
327 328 if after.all:
328 329 if after.success:
329 330 after.difference_update(self.all_completed)
330 331 if after.failure:
331 332 after.difference_update(self.all_failed)
332 333 if after.check(self.all_completed, self.all_failed):
333 334 # recast as empty set, if `after` already met,
334 335 # to prevent unnecessary set comparisons
335 336 after = MET
336 337
337 338 # location dependencies
338 339 follow = Dependency(header.get('follow', []))
339 340
340 341 # turn timeouts into datetime objects:
341 342 timeout = header.get('timeout', None)
342 343 if timeout:
343 344 timeout = datetime.now() + timedelta(0,timeout,0)
344 345
345 346 args = [raw_msg, targets, after, follow, timeout]
346 347
347 348 # validate and reduce dependencies:
348 349 for dep in after,follow:
349 350 # check valid:
350 351 if msg_id in dep or dep.difference(self.all_ids):
351 352 self.depending[msg_id] = args
352 353 return self.fail_unreachable(msg_id, error.InvalidDependency)
353 354 # check if unreachable:
354 355 if dep.unreachable(self.all_completed, self.all_failed):
355 356 self.depending[msg_id] = args
356 357 return self.fail_unreachable(msg_id)
357 358
358 359 if after.check(self.all_completed, self.all_failed):
359 360 # time deps already met, try to run
360 361 if not self.maybe_run(msg_id, *args):
361 362 # can't run yet
362 363 if msg_id not in self.all_failed:
363 364 # could have failed as unreachable
364 365 self.save_unmet(msg_id, *args)
365 366 else:
366 367 self.save_unmet(msg_id, *args)
367 368
368 369 # @logged
369 370 def audit_timeouts(self):
370 371 """Audit all waiting tasks for expired timeouts."""
371 372 now = datetime.now()
372 373 for msg_id in self.depending.keys():
373 374 # must recheck, in case one failure cascaded to another:
374 375 if msg_id in self.depending:
375 376 raw,targets,after,follow,timeout = self.depending[msg_id]
376 377 if timeout and timeout < now:
377 378 self.fail_unreachable(msg_id, error.TaskTimeout)
378 379
379 380 @logged
380 381 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
381 382 """a task has become unreachable, send a reply with an ImpossibleDependency
382 383 error."""
383 384 if msg_id not in self.depending:
384 385 self.log.error("msg %r already failed!"%msg_id)
385 386 return
386 387 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
387 388 for mid in follow.union(after):
388 389 if mid in self.graph:
389 390 self.graph[mid].remove(msg_id)
390 391
391 392 # FIXME: unpacking a message I've already unpacked, but didn't save:
392 393 idents,msg = self.session.feed_identities(raw_msg, copy=False)
393 394 header = self.session.unpack(msg[1].bytes)
394 395
395 396 try:
396 397 raise why()
397 398 except:
398 399 content = error.wrap_exception()
399 400
400 401 self.all_done.add(msg_id)
401 402 self.all_failed.add(msg_id)
402 403
403 404 msg = self.session.send(self.client_stream, 'apply_reply', content,
404 405 parent=header, ident=idents)
405 406 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
406 407
407 408 self.update_graph(msg_id, success=False)
408 409
409 410 @logged
410 411 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
411 412 """check location dependencies, and run if they are met."""
412 413 blacklist = self.blacklist.setdefault(msg_id, set())
413 414 if follow or targets or blacklist or self.hwm:
414 415 # we need a can_run filter
415 416 def can_run(idx):
416 417 # check hwm
417 418 if self.hwm and self.loads[idx] == self.hwm:
418 419 return False
419 420 target = self.targets[idx]
420 421 # check blacklist
421 422 if target in blacklist:
422 423 return False
423 424 # check targets
424 425 if targets and target not in targets:
425 426 return False
426 427 # check follow
427 428 return follow.check(self.completed[target], self.failed[target])
428 429
429 430 indices = filter(can_run, range(len(self.targets)))
430 431
431 432 if not indices:
432 433 # couldn't run
433 434 if follow.all:
434 435 # check follow for impossibility
435 436 dests = set()
436 437 relevant = set()
437 438 if follow.success:
438 439 relevant = self.all_completed
439 440 if follow.failure:
440 441 relevant = relevant.union(self.all_failed)
441 442 for m in follow.intersection(relevant):
442 443 dests.add(self.destinations[m])
443 444 if len(dests) > 1:
444 445 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
445 446 self.fail_unreachable(msg_id)
446 447 return False
447 448 if targets:
448 449 # check blacklist+targets for impossibility
449 450 targets.difference_update(blacklist)
450 451 if not targets or not targets.intersection(self.targets):
451 452 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
452 453 self.fail_unreachable(msg_id)
453 454 return False
454 455 return False
455 456 else:
456 457 indices = None
457 458
458 459 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
459 460 return True
460 461
461 462 @logged
462 463 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
463 464 """Save a message for later submission when its dependencies are met."""
464 465 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
465 466 # track the ids in follow or after, but not those already finished
466 467 for dep_id in after.union(follow).difference(self.all_done):
467 468 if dep_id not in self.graph:
468 469 self.graph[dep_id] = set()
469 470 self.graph[dep_id].add(msg_id)
470 471
471 472 @logged
472 473 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
473 474 """Submit a task to any of a subset of our targets."""
474 475 if indices:
475 476 loads = [self.loads[i] for i in indices]
476 477 else:
477 478 loads = self.loads
478 479 idx = self.scheme(loads)
479 480 if indices:
480 481 idx = indices[idx]
481 482 target = self.targets[idx]
482 483 # print (target, map(str, msg[:3]))
483 484 # send job to the engine
484 485 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
485 486 self.engine_stream.send_multipart(raw_msg, copy=False)
486 487 # update load
487 488 self.add_job(idx)
488 489 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
489 490 # notify Hub
490 491 content = dict(msg_id=msg_id, engine_id=target)
491 492 self.session.send(self.mon_stream, 'task_destination', content=content,
492 493 ident=['tracktask',self.session.session])
493 494
494 495
495 496 #-----------------------------------------------------------------------
496 497 # Result Handling
497 498 #-----------------------------------------------------------------------
498 499 @logged
499 500 def dispatch_result(self, raw_msg):
500 501 """dispatch method for result replies"""
501 502 try:
502 503 idents,msg = self.session.feed_identities(raw_msg, copy=False)
503 504 msg = self.session.unpack_message(msg, content=False, copy=False)
504 505 engine = idents[0]
505 506 try:
506 507 idx = self.targets.index(engine)
507 508 except ValueError:
508 509 pass # skip load-update for dead engines
509 510 else:
510 511 self.finish_job(idx)
511 512 except Exception:
512 513 self.log.error("task::Invalid result: %r"%raw_msg, exc_info=True)
513 514 return
514 515
515 516 header = msg['header']
516 517 parent = msg['parent_header']
517 518 if header.get('dependencies_met', True):
518 519 success = (header['status'] == 'ok')
519 520 msg_id = parent['msg_id']
520 521 retries = self.retries[msg_id]
521 522 if not success and retries > 0:
522 523 # failed
523 524 self.retries[msg_id] = retries - 1
524 525 self.handle_unmet_dependency(idents, parent)
525 526 else:
526 527 del self.retries[msg_id]
527 528 # relay to client and update graph
528 529 self.handle_result(idents, parent, raw_msg, success)
529 530 # send to Hub monitor
530 531 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
531 532 else:
532 533 self.handle_unmet_dependency(idents, parent)
533 534
534 535 @logged
535 536 def handle_result(self, idents, parent, raw_msg, success=True):
536 537 """handle a real task result, either success or failure"""
537 538 # first, relay result to client
538 539 engine = idents[0]
539 540 client = idents[1]
540 541 # swap_ids for XREP-XREP mirror
541 542 raw_msg[:2] = [client,engine]
542 543 # print (map(str, raw_msg[:4]))
543 544 self.client_stream.send_multipart(raw_msg, copy=False)
544 545 # now, update our data structures
545 546 msg_id = parent['msg_id']
546 547 self.blacklist.pop(msg_id, None)
547 548 self.pending[engine].pop(msg_id)
548 549 if success:
549 550 self.completed[engine].add(msg_id)
550 551 self.all_completed.add(msg_id)
551 552 else:
552 553 self.failed[engine].add(msg_id)
553 554 self.all_failed.add(msg_id)
554 555 self.all_done.add(msg_id)
555 556 self.destinations[msg_id] = engine
556 557
557 558 self.update_graph(msg_id, success)
558 559
559 560 @logged
560 561 def handle_unmet_dependency(self, idents, parent):
561 562 """handle an unmet dependency"""
562 563 engine = idents[0]
563 564 msg_id = parent['msg_id']
564 565
565 566 if msg_id not in self.blacklist:
566 567 self.blacklist[msg_id] = set()
567 568 self.blacklist[msg_id].add(engine)
568 569
569 570 args = self.pending[engine].pop(msg_id)
570 571 raw,targets,after,follow,timeout = args
571 572
572 573 if self.blacklist[msg_id] == targets:
573 574 self.depending[msg_id] = args
574 575 self.fail_unreachable(msg_id)
575 576 elif not self.maybe_run(msg_id, *args):
576 577 # resubmit failed
577 578 if msg_id not in self.all_failed:
578 579 # put it back in our dependency tree
579 580 self.save_unmet(msg_id, *args)
580 581
581 582 if self.hwm:
582 583 try:
583 584 idx = self.targets.index(engine)
584 585 except ValueError:
585 586 pass # skip load-update for dead engines
586 587 else:
587 588 if self.loads[idx] == self.hwm-1:
588 589 self.update_graph(None)
589 590
590 591
591 592
592 593 @logged
593 594 def update_graph(self, dep_id=None, success=True):
594 595 """dep_id just finished. Update our dependency
594 595 graph and submit any jobs that just became runnable.
596 597
597 598 Called with dep_id=None to update entire graph for hwm, but without finishing
598 599 a task.
599 600 """
600 601 # print ("\n\n***********")
601 602 # pprint (dep_id)
602 603 # pprint (self.graph)
603 604 # pprint (self.depending)
604 605 # pprint (self.all_completed)
605 606 # pprint (self.all_failed)
606 607 # print ("\n\n***********\n\n")
607 608 # update any jobs that depended on the dependency
608 609 jobs = self.graph.pop(dep_id, [])
609 610
610 611 # recheck *all* jobs if
610 611 # a) we have HWM and an engine just became no longer full
612 613 # or b) dep_id was given as None
613 614 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
614 615 jobs = self.depending.keys()
615 616
616 617 for msg_id in jobs:
617 618 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
618 619
619 620 if after.unreachable(self.all_completed, self.all_failed)\
620 621 or follow.unreachable(self.all_completed, self.all_failed):
621 622 self.fail_unreachable(msg_id)
622 623
623 624 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
624 625 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
625 626
626 627 self.depending.pop(msg_id)
627 628 for mid in follow.union(after):
628 629 if mid in self.graph:
629 630 self.graph[mid].remove(msg_id)
630 631
631 632 #----------------------------------------------------------------------
632 633 # methods to be overridden by subclasses
633 634 #----------------------------------------------------------------------
634 635
635 636 def add_job(self, idx):
636 637 """Called after self.targets[idx] just got the job with header.
637 638 Override in subclasses. The default ordering is simple LRU.
638 639 The default loads are the number of outstanding jobs."""
639 640 self.loads[idx] += 1
640 641 for lis in (self.targets, self.loads):
641 642 lis.append(lis.pop(idx))
642 643
643 644
644 645 def finish_job(self, idx):
645 646 """Called after self.targets[idx] just finished a job.
646 647 Override in subclasses."""
647 648 self.loads[idx] -= 1
648 649
649 650
650 651
651 652 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
652 653 logname='root', log_url=None, loglevel=logging.DEBUG,
653 identity=b'task'):
654 from zmq.eventloop import ioloop
655 from zmq.eventloop.zmqstream import ZMQStream
654 identity=b'task', in_thread=False):
655
656 ZMQStream = zmqstream.ZMQStream
656 657
657 658 if config:
658 659 # unwrap dict back into Config
659 660 config = Config(config)
660 661
661 ctx = zmq.Context()
662 loop = ioloop.IOLoop()
662 if in_thread:
663 # use instance() to get the same Context/Loop as our parent
664 ctx = zmq.Context.instance()
665 loop = ioloop.IOLoop.instance()
666 else:
667 # in a process, don't use instance()
668 # for safety with multiprocessing
669 ctx = zmq.Context()
670 loop = ioloop.IOLoop()
663 671 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
664 672 ins.setsockopt(zmq.IDENTITY, identity)
665 673 ins.bind(in_addr)
666 674
667 675 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
668 676 outs.setsockopt(zmq.IDENTITY, identity)
669 677 outs.bind(out_addr)
670 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
678 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
671 679 mons.connect(mon_addr)
672 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
680 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
673 681 nots.setsockopt(zmq.SUBSCRIBE, '')
674 682 nots.connect(not_addr)
675 683
676 # setup logging. Note that these will not work in-process, because they clobber
677 # existing loggers.
678 if log_url:
679 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
684 # setup logging.
685 if in_thread:
686 log = Application.instance().log
680 687 else:
681 log = local_logger(logname, loglevel)
688 if log_url:
689 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
690 else:
691 log = local_logger(logname, loglevel)
682 692
683 693 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
684 694 mon_stream=mons, notifier_stream=nots,
685 695 loop=loop, log=log,
686 696 config=config)
687 697 scheduler.start()
688 try:
689 loop.start()
690 except KeyboardInterrupt:
691 print ("interrupted, exiting...", file=sys.__stderr__)
698 if not in_thread:
699 try:
700 loop.start()
701 except KeyboardInterrupt:
702 print ("interrupted, exiting...", file=sys.__stderr__)
692 703
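For reference, a minimal sketch of the new in-thread entry point; the addresses are placeholders, and in the real application they come from HubFactory, as in init_schedulers above:

    from IPython.parallel.controller.scheduler import launch_scheduler

    # With in_thread=True the scheduler reuses zmq.Context.instance() and
    # ioloop.IOLoop.instance(), logs through Application.instance().log, and
    # returns without starting a private loop; the caller's IOLoop drives it.
    launch_scheduler('tcp://127.0.0.1:5570',  # in_addr:  client-facing XREP
                     'tcp://127.0.0.1:5571',  # out_addr: engine-facing XREP
                     'tcp://127.0.0.1:5572',  # mon_addr: monitor PUB
                     'tcp://127.0.0.1:5573',  # not_addr: notification SUB
                     in_thread=True)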