##// END OF EJS Templates
ipcontroller cleans up connection files unless reuse=True...
MinRK -
Show More
@@ -1,459 +1,493 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 from __future__ import with_statement
24 from __future__ import with_statement
25
25
26 import json
26 import json
27 import os
27 import os
28 import socket
28 import socket
29 import stat
29 import stat
30 import sys
30 import sys
31
31
32 from multiprocessing import Process
32 from multiprocessing import Process
33 from signal import signal, SIGINT, SIGABRT, SIGTERM
33
34
34 import zmq
35 import zmq
35 from zmq.devices import ProcessMonitoredQueue
36 from zmq.devices import ProcessMonitoredQueue
36 from zmq.log.handlers import PUBHandler
37 from zmq.log.handlers import PUBHandler
37
38
38 from IPython.core.profiledir import ProfileDir
39 from IPython.core.profiledir import ProfileDir
39
40
40 from IPython.parallel.apps.baseapp import (
41 from IPython.parallel.apps.baseapp import (
41 BaseParallelApplication,
42 BaseParallelApplication,
42 base_aliases,
43 base_aliases,
43 base_flags,
44 base_flags,
44 catch_config_error,
45 catch_config_error,
45 )
46 )
46 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
48
49
49 from IPython.zmq.session import (
50 from IPython.zmq.session import (
50 Session, session_aliases, session_flags, default_secure
51 Session, session_aliases, session_flags, default_secure
51 )
52 )
52
53
53 from IPython.parallel.controller.heartmonitor import HeartMonitor
54 from IPython.parallel.controller.heartmonitor import HeartMonitor
54 from IPython.parallel.controller.hub import HubFactory
55 from IPython.parallel.controller.hub import HubFactory
55 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
56 from IPython.parallel.controller.sqlitedb import SQLiteDB
57 from IPython.parallel.controller.sqlitedb import SQLiteDB
57
58
58 from IPython.parallel.util import signal_children, split_url, disambiguate_url
59 from IPython.parallel.util import split_url, disambiguate_url
59
60
60 # conditional import of MongoDB backend class
61 # conditional import of MongoDB backend class
61
62
62 try:
63 try:
63 from IPython.parallel.controller.mongodb import MongoDB
64 from IPython.parallel.controller.mongodb import MongoDB
64 except ImportError:
65 except ImportError:
65 maybe_mongo = []
66 maybe_mongo = []
66 else:
67 else:
67 maybe_mongo = [MongoDB]
68 maybe_mongo = [MongoDB]
68
69
69
70
70 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
71 # Module level variables
72 # Module level variables
72 #-----------------------------------------------------------------------------
73 #-----------------------------------------------------------------------------
73
74
74
75
75 #: The default config file name for this application
76 #: The default config file name for this application
76 default_config_file_name = u'ipcontroller_config.py'
77 default_config_file_name = u'ipcontroller_config.py'
77
78
78
79
79 _description = """Start the IPython controller for parallel computing.
80 _description = """Start the IPython controller for parallel computing.
80
81
81 The IPython controller provides a gateway between the IPython engines and
82 The IPython controller provides a gateway between the IPython engines and
82 clients. The controller needs to be started before the engines and can be
83 clients. The controller needs to be started before the engines and can be
83 configured using command line options or using a cluster directory. Cluster
84 configured using command line options or using a cluster directory. Cluster
84 directories contain config, log and security files and are usually located in
85 directories contain config, log and security files and are usually located in
85 your ipython directory and named as "profile_name". See the `profile`
86 your ipython directory and named as "profile_name". See the `profile`
86 and `profile-dir` options for details.
87 and `profile-dir` options for details.
87 """
88 """
88
89
89 _examples = """
90 _examples = """
90 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
91 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
91 ipcontroller --scheme=pure # use the pure zeromq scheduler
92 ipcontroller --scheme=pure # use the pure zeromq scheduler
92 """
93 """
93
94
94
95
95 #-----------------------------------------------------------------------------
96 #-----------------------------------------------------------------------------
96 # The main application
97 # The main application
97 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
98 flags = {}
99 flags = {}
99 flags.update(base_flags)
100 flags.update(base_flags)
100 flags.update({
101 flags.update({
101 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
102 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
102 'Use threads instead of processes for the schedulers'),
103 'Use threads instead of processes for the schedulers'),
103 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
104 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
104 'use the SQLiteDB backend'),
105 'use the SQLiteDB backend'),
105 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
106 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
106 'use the MongoDB backend'),
107 'use the MongoDB backend'),
107 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
108 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
108 'use the in-memory DictDB backend'),
109 'use the in-memory DictDB backend'),
109 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
110 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
110 """use dummy DB backend, which doesn't store any information.
111 """use dummy DB backend, which doesn't store any information.
111
112
112 This can be used to prevent growth of the memory footprint of the Hub
113 This can be used to prevent growth of the memory footprint of the Hub
113 in cases where its record-keeping is not required. Requesting results
114 in cases where its record-keeping is not required. Requesting results
114 of tasks submitted by other clients, db_queries, and task resubmission
115 of tasks submitted by other clients, db_queries, and task resubmission
115 will not be available."""),
116 will not be available."""),
116 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
117 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
117 'reuse existing json connection files')
118 'reuse existing json connection files')
118 })
119 })
119
120
120 flags.update(session_flags)
121 flags.update(session_flags)
121
122
122 aliases = dict(
123 aliases = dict(
123 ssh = 'IPControllerApp.ssh_server',
124 ssh = 'IPControllerApp.ssh_server',
124 enginessh = 'IPControllerApp.engine_ssh_server',
125 enginessh = 'IPControllerApp.engine_ssh_server',
125 location = 'IPControllerApp.location',
126 location = 'IPControllerApp.location',
126
127
127 url = 'HubFactory.url',
128 url = 'HubFactory.url',
128 ip = 'HubFactory.ip',
129 ip = 'HubFactory.ip',
129 transport = 'HubFactory.transport',
130 transport = 'HubFactory.transport',
130 port = 'HubFactory.regport',
131 port = 'HubFactory.regport',
131
132
132 ping = 'HeartMonitor.period',
133 ping = 'HeartMonitor.period',
133
134
134 scheme = 'TaskScheduler.scheme_name',
135 scheme = 'TaskScheduler.scheme_name',
135 hwm = 'TaskScheduler.hwm',
136 hwm = 'TaskScheduler.hwm',
136 )
137 )
137 aliases.update(base_aliases)
138 aliases.update(base_aliases)
138 aliases.update(session_aliases)
139 aliases.update(session_aliases)
139
140
140
141
141 class IPControllerApp(BaseParallelApplication):
142 class IPControllerApp(BaseParallelApplication):
142
143
143 name = u'ipcontroller'
144 name = u'ipcontroller'
144 description = _description
145 description = _description
145 examples = _examples
146 examples = _examples
146 config_file_name = Unicode(default_config_file_name)
147 config_file_name = Unicode(default_config_file_name)
147 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
148 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
148
149
149 # change default to True
150 # change default to True
150 auto_create = Bool(True, config=True,
151 auto_create = Bool(True, config=True,
151 help="""Whether to create profile dir if it doesn't exist.""")
152 help="""Whether to create profile dir if it doesn't exist.""")
152
153
153 reuse_files = Bool(False, config=True,
154 reuse_files = Bool(False, config=True,
154 help='Whether to reuse existing json connection files.'
155 help="""Whether to reuse existing json connection files.
156 If False, connection files will be removed on a clean exit.
157 """
155 )
158 )
156 ssh_server = Unicode(u'', config=True,
159 ssh_server = Unicode(u'', config=True,
157 help="""ssh url for clients to use when connecting to the Controller
160 help="""ssh url for clients to use when connecting to the Controller
158 processes. It should be of the form: [user@]server[:port]. The
161 processes. It should be of the form: [user@]server[:port]. The
159 Controller's listening addresses must be accessible from the ssh server""",
162 Controller's listening addresses must be accessible from the ssh server""",
160 )
163 )
161 engine_ssh_server = Unicode(u'', config=True,
164 engine_ssh_server = Unicode(u'', config=True,
162 help="""ssh url for engines to use when connecting to the Controller
165 help="""ssh url for engines to use when connecting to the Controller
163 processes. It should be of the form: [user@]server[:port]. The
166 processes. It should be of the form: [user@]server[:port]. The
164 Controller's listening addresses must be accessible from the ssh server""",
167 Controller's listening addresses must be accessible from the ssh server""",
165 )
168 )
166 location = Unicode(u'', config=True,
169 location = Unicode(u'', config=True,
167 help="""The external IP or domain name of the Controller, used for disambiguating
170 help="""The external IP or domain name of the Controller, used for disambiguating
168 engine and client connections.""",
171 engine and client connections.""",
169 )
172 )
170 import_statements = List([], config=True,
173 import_statements = List([], config=True,
171 help="import statements to be run at startup. Necessary in some environments"
174 help="import statements to be run at startup. Necessary in some environments"
172 )
175 )
173
176
174 use_threads = Bool(False, config=True,
177 use_threads = Bool(False, config=True,
175 help='Use threads instead of processes for the schedulers',
178 help='Use threads instead of processes for the schedulers',
176 )
179 )
177
180
178 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
181 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
179 help="JSON filename where engine connection info will be stored.")
182 help="JSON filename where engine connection info will be stored.")
180 client_json_file = Unicode('ipcontroller-client.json', config=True,
183 client_json_file = Unicode('ipcontroller-client.json', config=True,
181 help="JSON filename where client connection info will be stored.")
184 help="JSON filename where client connection info will be stored.")
182
185
183 def _cluster_id_changed(self, name, old, new):
186 def _cluster_id_changed(self, name, old, new):
184 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
187 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
185 self.engine_json_file = "%s-engine.json" % self.name
188 self.engine_json_file = "%s-engine.json" % self.name
186 self.client_json_file = "%s-client.json" % self.name
189 self.client_json_file = "%s-client.json" % self.name
187
190
188
191
189 # internal
192 # internal
190 children = List()
193 children = List()
191 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
194 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
192
195
193 def _use_threads_changed(self, name, old, new):
196 def _use_threads_changed(self, name, old, new):
194 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
197 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
195
198
199 write_connection_files = Bool(True,
200 help="""Whether to write connection files to disk.
201 True in all cases other than runs with `reuse_files=True` *after the first*
202 """
203 )
204
196 aliases = Dict(aliases)
205 aliases = Dict(aliases)
197 flags = Dict(flags)
206 flags = Dict(flags)
198
207
199
208
200 def save_connection_dict(self, fname, cdict):
209 def save_connection_dict(self, fname, cdict):
201 """save a connection dict to json file."""
210 """save a connection dict to json file."""
202 c = self.config
211 c = self.config
203 url = cdict['url']
212 url = cdict['url']
204 location = cdict['location']
213 location = cdict['location']
205 if not location:
214 if not location:
206 try:
215 try:
207 proto,ip,port = split_url(url)
216 proto,ip,port = split_url(url)
208 except AssertionError:
217 except AssertionError:
209 pass
218 pass
210 else:
219 else:
211 try:
220 try:
212 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
221 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
213 except (socket.gaierror, IndexError):
222 except (socket.gaierror, IndexError):
214 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
223 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
215 " You may need to specify '--location=<external_ip_address>' to help"
224 " You may need to specify '--location=<external_ip_address>' to help"
216 " IPython decide when to connect via loopback.")
225 " IPython decide when to connect via loopback.")
217 location = '127.0.0.1'
226 location = '127.0.0.1'
218 cdict['location'] = location
227 cdict['location'] = location
219 fname = os.path.join(self.profile_dir.security_dir, fname)
228 fname = os.path.join(self.profile_dir.security_dir, fname)
220 self.log.info("writing connection info to %s", fname)
229 self.log.info("writing connection info to %s", fname)
221 with open(fname, 'w') as f:
230 with open(fname, 'w') as f:
222 f.write(json.dumps(cdict, indent=2))
231 f.write(json.dumps(cdict, indent=2))
223 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
232 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
224
233
225 def load_config_from_json(self):
234 def load_config_from_json(self):
226 """load config from existing json connector files."""
235 """load config from existing json connector files."""
227 c = self.config
236 c = self.config
228 self.log.debug("loading config from JSON")
237 self.log.debug("loading config from JSON")
229 # load from engine config
238 # load from engine config
230 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
239 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
231 self.log.info("loading connection info from %s", fname)
240 self.log.info("loading connection info from %s", fname)
232 with open(fname) as f:
241 with open(fname) as f:
233 cfg = json.loads(f.read())
242 cfg = json.loads(f.read())
234 key = cfg['exec_key']
243 key = cfg['exec_key']
235 # json gives unicode, Session.key wants bytes
244 # json gives unicode, Session.key wants bytes
236 c.Session.key = key.encode('ascii')
245 c.Session.key = key.encode('ascii')
237 xport,addr = cfg['url'].split('://')
246 xport,addr = cfg['url'].split('://')
238 c.HubFactory.engine_transport = xport
247 c.HubFactory.engine_transport = xport
239 ip,ports = addr.split(':')
248 ip,ports = addr.split(':')
240 c.HubFactory.engine_ip = ip
249 c.HubFactory.engine_ip = ip
241 c.HubFactory.regport = int(ports)
250 c.HubFactory.regport = int(ports)
242 self.location = cfg['location']
251 self.location = cfg['location']
243 if not self.engine_ssh_server:
252 if not self.engine_ssh_server:
244 self.engine_ssh_server = cfg['ssh']
253 self.engine_ssh_server = cfg['ssh']
245 # load client config
254 # load client config
246 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
255 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
247 self.log.info("loading connection info from %s", fname)
256 self.log.info("loading connection info from %s", fname)
248 with open(fname) as f:
257 with open(fname) as f:
249 cfg = json.loads(f.read())
258 cfg = json.loads(f.read())
250 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
259 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
251 xport,addr = cfg['url'].split('://')
260 xport,addr = cfg['url'].split('://')
252 c.HubFactory.client_transport = xport
261 c.HubFactory.client_transport = xport
253 ip,ports = addr.split(':')
262 ip,ports = addr.split(':')
254 c.HubFactory.client_ip = ip
263 c.HubFactory.client_ip = ip
255 if not self.ssh_server:
264 if not self.ssh_server:
256 self.ssh_server = cfg['ssh']
265 self.ssh_server = cfg['ssh']
257 assert int(ports) == c.HubFactory.regport, "regport mismatch"
266 assert int(ports) == c.HubFactory.regport, "regport mismatch"
258
267
268 def cleanup_connection_files(self):
269 if self.reuse_files:
270 self.log.debug("leaving JSON connection files for reuse")
271 return
272 self.log.debug("cleaning up JSON connection files")
273 for f in (self.client_json_file, self.engine_json_file):
274 f = os.path.join(self.profile_dir.security_dir, f)
275 try:
276 os.remove(f)
277 except Exception as e:
278 self.log.error("Failed to cleanup connection file: %s", e)
279 else:
280 self.log.debug(u"removed %s", f)
281
259 def load_secondary_config(self):
282 def load_secondary_config(self):
260 """secondary config, loading from JSON and setting defaults"""
283 """secondary config, loading from JSON and setting defaults"""
261 if self.reuse_files:
284 if self.reuse_files:
262 try:
285 try:
263 self.load_config_from_json()
286 self.load_config_from_json()
264 except (AssertionError,IOError) as e:
287 except (AssertionError,IOError) as e:
265 self.log.error("Could not load config from JSON: %s" % e)
288 self.log.error("Could not load config from JSON: %s" % e)
266 self.reuse_files=False
289 else:
290 # successfully loaded config from JSON, and reuse=True
291 # no need to wite back the same file
292 self.write_connection_files = False
293
267 # switch Session.key default to secure
294 # switch Session.key default to secure
268 default_secure(self.config)
295 default_secure(self.config)
269 self.log.debug("Config changed")
296 self.log.debug("Config changed")
270 self.log.debug(repr(self.config))
297 self.log.debug(repr(self.config))
271
298
272 def init_hub(self):
299 def init_hub(self):
273 c = self.config
300 c = self.config
274
301
275 self.do_import_statements()
302 self.do_import_statements()
276
303
277 try:
304 try:
278 self.factory = HubFactory(config=c, log=self.log)
305 self.factory = HubFactory(config=c, log=self.log)
279 # self.start_logging()
306 # self.start_logging()
280 self.factory.init_hub()
307 self.factory.init_hub()
281 except TraitError:
308 except TraitError:
282 raise
309 raise
283 except Exception:
310 except Exception:
284 self.log.error("Couldn't construct the Controller", exc_info=True)
311 self.log.error("Couldn't construct the Controller", exc_info=True)
285 self.exit(1)
312 self.exit(1)
286
313
287 if not self.reuse_files:
314 if self.write_connection_files:
288 # save to new json config files
315 # save to new json config files
289 f = self.factory
316 f = self.factory
290 cdict = {'exec_key' : f.session.key.decode('ascii'),
317 cdict = {'exec_key' : f.session.key.decode('ascii'),
291 'ssh' : self.ssh_server,
318 'ssh' : self.ssh_server,
292 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
319 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
293 'location' : self.location
320 'location' : self.location
294 }
321 }
295 self.save_connection_dict(self.client_json_file, cdict)
322 self.save_connection_dict(self.client_json_file, cdict)
296 edict = cdict
323 edict = cdict
297 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
324 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
298 edict['ssh'] = self.engine_ssh_server
325 edict['ssh'] = self.engine_ssh_server
299 self.save_connection_dict(self.engine_json_file, edict)
326 self.save_connection_dict(self.engine_json_file, edict)
300
327
301 #
302 def init_schedulers(self):
328 def init_schedulers(self):
303 children = self.children
329 children = self.children
304 mq = import_item(str(self.mq_class))
330 mq = import_item(str(self.mq_class))
305
331
306 hub = self.factory
332 hub = self.factory
307 # disambiguate url, in case of *
333 # disambiguate url, in case of *
308 monitor_url = disambiguate_url(hub.monitor_url)
334 monitor_url = disambiguate_url(hub.monitor_url)
309 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
335 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
310 # IOPub relay (in a Process)
336 # IOPub relay (in a Process)
311 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
337 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
312 q.bind_in(hub.client_info['iopub'])
338 q.bind_in(hub.client_info['iopub'])
313 q.bind_out(hub.engine_info['iopub'])
339 q.bind_out(hub.engine_info['iopub'])
314 q.setsockopt_out(zmq.SUBSCRIBE, b'')
340 q.setsockopt_out(zmq.SUBSCRIBE, b'')
315 q.connect_mon(monitor_url)
341 q.connect_mon(monitor_url)
316 q.daemon=True
342 q.daemon=True
317 children.append(q)
343 children.append(q)
318
344
319 # Multiplexer Queue (in a Process)
345 # Multiplexer Queue (in a Process)
320 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
346 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
321 q.bind_in(hub.client_info['mux'])
347 q.bind_in(hub.client_info['mux'])
322 q.setsockopt_in(zmq.IDENTITY, b'mux')
348 q.setsockopt_in(zmq.IDENTITY, b'mux')
323 q.bind_out(hub.engine_info['mux'])
349 q.bind_out(hub.engine_info['mux'])
324 q.connect_mon(monitor_url)
350 q.connect_mon(monitor_url)
325 q.daemon=True
351 q.daemon=True
326 children.append(q)
352 children.append(q)
327
353
328 # Control Queue (in a Process)
354 # Control Queue (in a Process)
329 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
355 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
330 q.bind_in(hub.client_info['control'])
356 q.bind_in(hub.client_info['control'])
331 q.setsockopt_in(zmq.IDENTITY, b'control')
357 q.setsockopt_in(zmq.IDENTITY, b'control')
332 q.bind_out(hub.engine_info['control'])
358 q.bind_out(hub.engine_info['control'])
333 q.connect_mon(monitor_url)
359 q.connect_mon(monitor_url)
334 q.daemon=True
360 q.daemon=True
335 children.append(q)
361 children.append(q)
336 try:
362 try:
337 scheme = self.config.TaskScheduler.scheme_name
363 scheme = self.config.TaskScheduler.scheme_name
338 except AttributeError:
364 except AttributeError:
339 scheme = TaskScheduler.scheme_name.get_default_value()
365 scheme = TaskScheduler.scheme_name.get_default_value()
340 # Task Queue (in a Process)
366 # Task Queue (in a Process)
341 if scheme == 'pure':
367 if scheme == 'pure':
342 self.log.warn("task::using pure XREQ Task scheduler")
368 self.log.warn("task::using pure XREQ Task scheduler")
343 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
369 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
344 # q.setsockopt_out(zmq.HWM, hub.hwm)
370 # q.setsockopt_out(zmq.HWM, hub.hwm)
345 q.bind_in(hub.client_info['task'][1])
371 q.bind_in(hub.client_info['task'][1])
346 q.setsockopt_in(zmq.IDENTITY, b'task')
372 q.setsockopt_in(zmq.IDENTITY, b'task')
347 q.bind_out(hub.engine_info['task'])
373 q.bind_out(hub.engine_info['task'])
348 q.connect_mon(monitor_url)
374 q.connect_mon(monitor_url)
349 q.daemon=True
375 q.daemon=True
350 children.append(q)
376 children.append(q)
351 elif scheme == 'none':
377 elif scheme == 'none':
352 self.log.warn("task::using no Task scheduler")
378 self.log.warn("task::using no Task scheduler")
353
379
354 else:
380 else:
355 self.log.info("task::using Python %s Task scheduler"%scheme)
381 self.log.info("task::using Python %s Task scheduler"%scheme)
356 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
382 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
357 monitor_url, disambiguate_url(hub.client_info['notification']))
383 monitor_url, disambiguate_url(hub.client_info['notification']))
358 kwargs = dict(logname='scheduler', loglevel=self.log_level,
384 kwargs = dict(logname='scheduler', loglevel=self.log_level,
359 log_url = self.log_url, config=dict(self.config))
385 log_url = self.log_url, config=dict(self.config))
360 if 'Process' in self.mq_class:
386 if 'Process' in self.mq_class:
361 # run the Python scheduler in a Process
387 # run the Python scheduler in a Process
362 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
388 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
363 q.daemon=True
389 q.daemon=True
364 children.append(q)
390 children.append(q)
365 else:
391 else:
366 # single-threaded Controller
392 # single-threaded Controller
367 kwargs['in_thread'] = True
393 kwargs['in_thread'] = True
368 launch_scheduler(*sargs, **kwargs)
394 launch_scheduler(*sargs, **kwargs)
369
395
396 def terminate_children(self):
397 child_procs = []
398 for child in self.children:
399 if isinstance(child, ProcessMonitoredQueue):
400 child_procs.append(child.launcher)
401 elif isinstance(child, Process):
402 child_procs.append(child)
403 if child_procs:
404 self.log.critical("terminating children...")
405 for child in child_procs:
406 try:
407 child.terminate()
408 except OSError:
409 # already dead
410 pass
370
411
371 def save_urls(self):
412 def handle_signal(self, sig, frame):
372 """save the registration urls to files."""
413 self.log.critical("Received signal %i, shutting down", sig)
373 c = self.config
414 self.terminate_children()
374
415 self.loop.stop()
375 sec_dir = self.profile_dir.security_dir
376 cf = self.factory
377
378 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
379 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
380
381 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
382 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
383
416
417 def init_signal(self):
418 for sig in (SIGINT, SIGABRT, SIGTERM):
419 signal(sig, self.handle_signal)
384
420
385 def do_import_statements(self):
421 def do_import_statements(self):
386 statements = self.import_statements
422 statements = self.import_statements
387 for s in statements:
423 for s in statements:
388 try:
424 try:
389 self.log.msg("Executing statement: '%s'" % s)
425 self.log.msg("Executing statement: '%s'" % s)
390 exec s in globals(), locals()
426 exec s in globals(), locals()
391 except:
427 except:
392 self.log.msg("Error running statement: %s" % s)
428 self.log.msg("Error running statement: %s" % s)
393
429
394 def forward_logging(self):
430 def forward_logging(self):
395 if self.log_url:
431 if self.log_url:
396 self.log.info("Forwarding logging to %s"%self.log_url)
432 self.log.info("Forwarding logging to %s"%self.log_url)
397 context = zmq.Context.instance()
433 context = zmq.Context.instance()
398 lsock = context.socket(zmq.PUB)
434 lsock = context.socket(zmq.PUB)
399 lsock.connect(self.log_url)
435 lsock.connect(self.log_url)
400 handler = PUBHandler(lsock)
436 handler = PUBHandler(lsock)
401 self.log.removeHandler(self._log_handler)
437 self.log.removeHandler(self._log_handler)
402 handler.root_topic = 'controller'
438 handler.root_topic = 'controller'
403 handler.setLevel(self.log_level)
439 handler.setLevel(self.log_level)
404 self.log.addHandler(handler)
440 self.log.addHandler(handler)
405 self._log_handler = handler
441 self._log_handler = handler
406
442
407 @catch_config_error
443 @catch_config_error
408 def initialize(self, argv=None):
444 def initialize(self, argv=None):
409 super(IPControllerApp, self).initialize(argv)
445 super(IPControllerApp, self).initialize(argv)
410 self.forward_logging()
446 self.forward_logging()
411 self.load_secondary_config()
447 self.load_secondary_config()
412 self.init_hub()
448 self.init_hub()
413 self.init_schedulers()
449 self.init_schedulers()
414
450
415 def start(self):
451 def start(self):
416 # Start the subprocesses:
452 # Start the subprocesses:
417 self.factory.start()
453 self.factory.start()
418 child_procs = []
454 # children must be started before signals are setup,
455 # otherwise signal-handling will fire multiple times
419 for child in self.children:
456 for child in self.children:
420 child.start()
457 child.start()
421 if isinstance(child, ProcessMonitoredQueue):
458 self.init_signal()
422 child_procs.append(child.launcher)
423 elif isinstance(child, Process):
424 child_procs.append(child)
425 if child_procs:
426 signal_children(child_procs)
427
459
428 self.write_pid_file(overwrite=True)
460 self.write_pid_file(overwrite=True)
429
461
430 try:
462 try:
431 self.factory.loop.start()
463 self.factory.loop.start()
432 except KeyboardInterrupt:
464 except KeyboardInterrupt:
433 self.log.critical("Interrupted, Exiting...\n")
465 self.log.critical("Interrupted, Exiting...\n")
466 finally:
467 self.cleanup_connection_files()
434
468
435
469
436
470
437 def launch_new_instance():
471 def launch_new_instance():
438 """Create and run the IPython controller"""
472 """Create and run the IPython controller"""
439 if sys.platform == 'win32':
473 if sys.platform == 'win32':
440 # make sure we don't get called from a multiprocessing subprocess
474 # make sure we don't get called from a multiprocessing subprocess
441 # this can result in infinite Controllers being started on Windows
475 # this can result in infinite Controllers being started on Windows
442 # which doesn't have a proper fork, so multiprocessing is wonky
476 # which doesn't have a proper fork, so multiprocessing is wonky
443
477
444 # this only comes up when IPython has been installed using vanilla
478 # this only comes up when IPython has been installed using vanilla
445 # setuptools, and *not* distribute.
479 # setuptools, and *not* distribute.
446 import multiprocessing
480 import multiprocessing
447 p = multiprocessing.current_process()
481 p = multiprocessing.current_process()
448 # the main process has name 'MainProcess'
482 # the main process has name 'MainProcess'
449 # subprocesses will have names like 'Process-1'
483 # subprocesses will have names like 'Process-1'
450 if p.name != 'MainProcess':
484 if p.name != 'MainProcess':
451 # we are a subprocess, don't start another Controller!
485 # we are a subprocess, don't start another Controller!
452 return
486 return
453 app = IPControllerApp.instance()
487 app = IPControllerApp.instance()
454 app.initialize()
488 app.initialize()
455 app.start()
489 app.start()
456
490
457
491
458 if __name__ == '__main__':
492 if __name__ == '__main__':
459 launch_new_instance()
493 launch_new_instance()
@@ -1,726 +1,726 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 from __future__ import print_function
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26
26
27 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44
44
45 from IPython.parallel import error
45 from IPython.parallel import error
46 from IPython.parallel.factory import SessionFactory
46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger, asbytes
47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48
48
49 from .dependency import Dependency
49 from .dependency import Dependency
50
50
51 @decorator
51 @decorator
52 def logged(f,self,*args,**kwargs):
52 def logged(f,self,*args,**kwargs):
53 # print ("#--------------------")
53 # print ("#--------------------")
54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 # print ("#--")
55 # print ("#--")
56 return f(self,*args, **kwargs)
56 return f(self,*args, **kwargs)
57
57
58 #----------------------------------------------------------------------
58 #----------------------------------------------------------------------
59 # Chooser functions
59 # Chooser functions
60 #----------------------------------------------------------------------
60 #----------------------------------------------------------------------
61
61
62 def plainrandom(loads):
62 def plainrandom(loads):
63 """Plain random pick."""
63 """Plain random pick."""
64 n = len(loads)
64 n = len(loads)
65 return randint(0,n-1)
65 return randint(0,n-1)
66
66
67 def lru(loads):
67 def lru(loads):
68 """Always pick the front of the line.
68 """Always pick the front of the line.
69
69
70 The content of `loads` is ignored.
70 The content of `loads` is ignored.
71
71
72 Assumes LRU ordering of loads, with oldest first.
72 Assumes LRU ordering of loads, with oldest first.
73 """
73 """
74 return 0
74 return 0
75
75
76 def twobin(loads):
76 def twobin(loads):
77 """Pick two at random, use the LRU of the two.
77 """Pick two at random, use the LRU of the two.
78
78
79 The content of loads is ignored.
79 The content of loads is ignored.
80
80
81 Assumes LRU ordering of loads, with oldest first.
81 Assumes LRU ordering of loads, with oldest first.
82 """
82 """
83 n = len(loads)
83 n = len(loads)
84 a = randint(0,n-1)
84 a = randint(0,n-1)
85 b = randint(0,n-1)
85 b = randint(0,n-1)
86 return min(a,b)
86 return min(a,b)
87
87
88 def weighted(loads):
88 def weighted(loads):
89 """Pick two at random using inverse load as weight.
89 """Pick two at random using inverse load as weight.
90
90
91 Return the less loaded of the two.
91 Return the less loaded of the two.
92 """
92 """
93 # weight 0 a million times more than 1:
93 # weight 0 a million times more than 1:
94 weights = 1./(1e-6+numpy.array(loads))
94 weights = 1./(1e-6+numpy.array(loads))
95 sums = weights.cumsum()
95 sums = weights.cumsum()
96 t = sums[-1]
96 t = sums[-1]
97 x = random()*t
97 x = random()*t
98 y = random()*t
98 y = random()*t
99 idx = 0
99 idx = 0
100 idy = 0
100 idy = 0
101 while sums[idx] < x:
101 while sums[idx] < x:
102 idx += 1
102 idx += 1
103 while sums[idy] < y:
103 while sums[idy] < y:
104 idy += 1
104 idy += 1
105 if weights[idy] > weights[idx]:
105 if weights[idy] > weights[idx]:
106 return idy
106 return idy
107 else:
107 else:
108 return idx
108 return idx
109
109
110 def leastload(loads):
110 def leastload(loads):
111 """Always choose the lowest load.
111 """Always choose the lowest load.
112
112
113 If the lowest load occurs more than once, the first
113 If the lowest load occurs more than once, the first
114 occurance will be used. If loads has LRU ordering, this means
114 occurance will be used. If loads has LRU ordering, this means
115 the LRU of those with the lowest load is chosen.
115 the LRU of those with the lowest load is chosen.
116 """
116 """
117 return loads.index(min(loads))
117 return loads.index(min(loads))
118
118
119 #---------------------------------------------------------------------
119 #---------------------------------------------------------------------
120 # Classes
120 # Classes
121 #---------------------------------------------------------------------
121 #---------------------------------------------------------------------
122 # store empty default dependency:
122 # store empty default dependency:
123 MET = Dependency([])
123 MET = Dependency([])
124
124
125 class TaskScheduler(SessionFactory):
125 class TaskScheduler(SessionFactory):
126 """Python TaskScheduler object.
126 """Python TaskScheduler object.
127
127
128 This is the simplest object that supports msg_id based
128 This is the simplest object that supports msg_id based
129 DAG dependencies. *Only* task msg_ids are checked, not
129 DAG dependencies. *Only* task msg_ids are checked, not
130 msg_ids of jobs submitted via the MUX queue.
130 msg_ids of jobs submitted via the MUX queue.
131
131
132 """
132 """
133
133
134 hwm = Integer(1, config=True,
134 hwm = Integer(1, config=True,
135 help="""specify the High Water Mark (HWM) for the downstream
135 help="""specify the High Water Mark (HWM) for the downstream
136 socket in the Task scheduler. This is the maximum number
136 socket in the Task scheduler. This is the maximum number
137 of allowed outstanding tasks on each engine.
137 of allowed outstanding tasks on each engine.
138
138
139 The default (1) means that only one task can be outstanding on each
139 The default (1) means that only one task can be outstanding on each
140 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
140 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
141 engines continue to be assigned tasks while they are working,
141 engines continue to be assigned tasks while they are working,
142 effectively hiding network latency behind computation, but can result
142 effectively hiding network latency behind computation, but can result
143 in an imbalance of work when submitting many heterogenous tasks all at
143 in an imbalance of work when submitting many heterogenous tasks all at
144 once. Any positive value greater than one is a compromise between the
144 once. Any positive value greater than one is a compromise between the
145 two.
145 two.
146
146
147 """
147 """
148 )
148 )
149 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
149 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
150 'leastload', config=True, allow_none=False,
150 'leastload', config=True, allow_none=False,
151 help="""select the task scheduler scheme [default: Python LRU]
151 help="""select the task scheduler scheme [default: Python LRU]
152 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
152 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
153 )
153 )
154 def _scheme_name_changed(self, old, new):
154 def _scheme_name_changed(self, old, new):
155 self.log.debug("Using scheme %r"%new)
155 self.log.debug("Using scheme %r"%new)
156 self.scheme = globals()[new]
156 self.scheme = globals()[new]
157
157
158 # input arguments:
158 # input arguments:
159 scheme = Instance(FunctionType) # function for determining the destination
159 scheme = Instance(FunctionType) # function for determining the destination
160 def _scheme_default(self):
160 def _scheme_default(self):
161 return leastload
161 return leastload
162 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
162 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
163 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
163 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
164 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
164 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
165 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
165 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
166
166
167 # internals:
167 # internals:
168 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
168 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
169 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
169 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
170 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
170 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
171 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
171 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
172 pending = Dict() # dict by engine_uuid of submitted tasks
172 pending = Dict() # dict by engine_uuid of submitted tasks
173 completed = Dict() # dict by engine_uuid of completed tasks
173 completed = Dict() # dict by engine_uuid of completed tasks
174 failed = Dict() # dict by engine_uuid of failed tasks
174 failed = Dict() # dict by engine_uuid of failed tasks
175 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
175 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
176 clients = Dict() # dict by msg_id for who submitted the task
176 clients = Dict() # dict by msg_id for who submitted the task
177 targets = List() # list of target IDENTs
177 targets = List() # list of target IDENTs
178 loads = List() # list of engine loads
178 loads = List() # list of engine loads
179 # full = Set() # set of IDENTs that have HWM outstanding tasks
179 # full = Set() # set of IDENTs that have HWM outstanding tasks
180 all_completed = Set() # set of all completed tasks
180 all_completed = Set() # set of all completed tasks
181 all_failed = Set() # set of all failed tasks
181 all_failed = Set() # set of all failed tasks
182 all_done = Set() # set of all finished tasks=union(completed,failed)
182 all_done = Set() # set of all finished tasks=union(completed,failed)
183 all_ids = Set() # set of all submitted task IDs
183 all_ids = Set() # set of all submitted task IDs
184 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
184 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
185 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
185 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
186
186
187 ident = CBytes() # ZMQ identity. This should just be self.session.session
187 ident = CBytes() # ZMQ identity. This should just be self.session.session
188 # but ensure Bytes
188 # but ensure Bytes
189 def _ident_default(self):
189 def _ident_default(self):
190 return self.session.bsession
190 return self.session.bsession
191
191
192 def start(self):
192 def start(self):
193 self.engine_stream.on_recv(self.dispatch_result, copy=False)
193 self.engine_stream.on_recv(self.dispatch_result, copy=False)
194 self._notification_handlers = dict(
194 self._notification_handlers = dict(
195 registration_notification = self._register_engine,
195 registration_notification = self._register_engine,
196 unregistration_notification = self._unregister_engine
196 unregistration_notification = self._unregister_engine
197 )
197 )
198 self.notifier_stream.on_recv(self.dispatch_notification)
198 self.notifier_stream.on_recv(self.dispatch_notification)
199 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
199 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
200 self.auditor.start()
200 self.auditor.start()
201 self.log.info("Scheduler started [%s]"%self.scheme_name)
201 self.log.info("Scheduler started [%s]"%self.scheme_name)
202
202
203 def resume_receiving(self):
203 def resume_receiving(self):
204 """Resume accepting jobs."""
204 """Resume accepting jobs."""
205 self.client_stream.on_recv(self.dispatch_submission, copy=False)
205 self.client_stream.on_recv(self.dispatch_submission, copy=False)
206
206
207 def stop_receiving(self):
207 def stop_receiving(self):
208 """Stop accepting jobs while there are no engines.
208 """Stop accepting jobs while there are no engines.
209 Leave them in the ZMQ queue."""
209 Leave them in the ZMQ queue."""
210 self.client_stream.on_recv(None)
210 self.client_stream.on_recv(None)
211
211
212 #-----------------------------------------------------------------------
212 #-----------------------------------------------------------------------
213 # [Un]Registration Handling
213 # [Un]Registration Handling
214 #-----------------------------------------------------------------------
214 #-----------------------------------------------------------------------
215
215
216 def dispatch_notification(self, msg):
216 def dispatch_notification(self, msg):
217 """dispatch register/unregister events."""
217 """dispatch register/unregister events."""
218 try:
218 try:
219 idents,msg = self.session.feed_identities(msg)
219 idents,msg = self.session.feed_identities(msg)
220 except ValueError:
220 except ValueError:
221 self.log.warn("task::Invalid Message: %r",msg)
221 self.log.warn("task::Invalid Message: %r",msg)
222 return
222 return
223 try:
223 try:
224 msg = self.session.unserialize(msg)
224 msg = self.session.unserialize(msg)
225 except ValueError:
225 except ValueError:
226 self.log.warn("task::Unauthorized message from: %r"%idents)
226 self.log.warn("task::Unauthorized message from: %r"%idents)
227 return
227 return
228
228
229 msg_type = msg['header']['msg_type']
229 msg_type = msg['header']['msg_type']
230
230
231 handler = self._notification_handlers.get(msg_type, None)
231 handler = self._notification_handlers.get(msg_type, None)
232 if handler is None:
232 if handler is None:
233 self.log.error("Unhandled message type: %r"%msg_type)
233 self.log.error("Unhandled message type: %r"%msg_type)
234 else:
234 else:
235 try:
235 try:
236 handler(asbytes(msg['content']['queue']))
236 handler(asbytes(msg['content']['queue']))
237 except Exception:
237 except Exception:
238 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
238 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
239
239
240 def _register_engine(self, uid):
240 def _register_engine(self, uid):
241 """New engine with ident `uid` became available."""
241 """New engine with ident `uid` became available."""
242 # head of the line:
242 # head of the line:
243 self.targets.insert(0,uid)
243 self.targets.insert(0,uid)
244 self.loads.insert(0,0)
244 self.loads.insert(0,0)
245
245
246 # initialize sets
246 # initialize sets
247 self.completed[uid] = set()
247 self.completed[uid] = set()
248 self.failed[uid] = set()
248 self.failed[uid] = set()
249 self.pending[uid] = {}
249 self.pending[uid] = {}
250 if len(self.targets) == 1:
250 if len(self.targets) == 1:
251 self.resume_receiving()
251 self.resume_receiving()
252 # rescan the graph:
252 # rescan the graph:
253 self.update_graph(None)
253 self.update_graph(None)
254
254
255 def _unregister_engine(self, uid):
255 def _unregister_engine(self, uid):
256 """Existing engine with ident `uid` became unavailable."""
256 """Existing engine with ident `uid` became unavailable."""
257 if len(self.targets) == 1:
257 if len(self.targets) == 1:
258 # this was our only engine
258 # this was our only engine
259 self.stop_receiving()
259 self.stop_receiving()
260
260
261 # handle any potentially finished tasks:
261 # handle any potentially finished tasks:
262 self.engine_stream.flush()
262 self.engine_stream.flush()
263
263
264 # don't pop destinations, because they might be used later
264 # don't pop destinations, because they might be used later
265 # map(self.destinations.pop, self.completed.pop(uid))
265 # map(self.destinations.pop, self.completed.pop(uid))
266 # map(self.destinations.pop, self.failed.pop(uid))
266 # map(self.destinations.pop, self.failed.pop(uid))
267
267
268 # prevent this engine from receiving work
268 # prevent this engine from receiving work
269 idx = self.targets.index(uid)
269 idx = self.targets.index(uid)
270 self.targets.pop(idx)
270 self.targets.pop(idx)
271 self.loads.pop(idx)
271 self.loads.pop(idx)
272
272
273 # wait 5 seconds before cleaning up pending jobs, since the results might
273 # wait 5 seconds before cleaning up pending jobs, since the results might
274 # still be incoming
274 # still be incoming
275 if self.pending[uid]:
275 if self.pending[uid]:
276 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
276 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
277 dc.start()
277 dc.start()
278 else:
278 else:
279 self.completed.pop(uid)
279 self.completed.pop(uid)
280 self.failed.pop(uid)
280 self.failed.pop(uid)
281
281
282
282
283 def handle_stranded_tasks(self, engine):
283 def handle_stranded_tasks(self, engine):
284 """Deal with jobs resident in an engine that died."""
284 """Deal with jobs resident in an engine that died."""
285 lost = self.pending[engine]
285 lost = self.pending[engine]
286 for msg_id in lost.keys():
286 for msg_id in lost.keys():
287 if msg_id not in self.pending[engine]:
287 if msg_id not in self.pending[engine]:
288 # prevent double-handling of messages
288 # prevent double-handling of messages
289 continue
289 continue
290
290
291 raw_msg = lost[msg_id][0]
291 raw_msg = lost[msg_id][0]
292 idents,msg = self.session.feed_identities(raw_msg, copy=False)
292 idents,msg = self.session.feed_identities(raw_msg, copy=False)
293 parent = self.session.unpack(msg[1].bytes)
293 parent = self.session.unpack(msg[1].bytes)
294 idents = [engine, idents[0]]
294 idents = [engine, idents[0]]
295
295
296 # build fake error reply
296 # build fake error reply
297 try:
297 try:
298 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
298 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
299 except:
299 except:
300 content = error.wrap_exception()
300 content = error.wrap_exception()
301 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
301 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
302 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
302 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
303 # and dispatch it
303 # and dispatch it
304 self.dispatch_result(raw_reply)
304 self.dispatch_result(raw_reply)
305
305
306 # finally scrub completed/failed lists
306 # finally scrub completed/failed lists
307 self.completed.pop(engine)
307 self.completed.pop(engine)
308 self.failed.pop(engine)
308 self.failed.pop(engine)
309
309
310
310
311 #-----------------------------------------------------------------------
311 #-----------------------------------------------------------------------
312 # Job Submission
312 # Job Submission
313 #-----------------------------------------------------------------------
313 #-----------------------------------------------------------------------
314 def dispatch_submission(self, raw_msg):
314 def dispatch_submission(self, raw_msg):
315 """Dispatch job submission to appropriate handlers."""
315 """Dispatch job submission to appropriate handlers."""
316 # ensure targets up to date:
316 # ensure targets up to date:
317 self.notifier_stream.flush()
317 self.notifier_stream.flush()
318 try:
318 try:
319 idents, msg = self.session.feed_identities(raw_msg, copy=False)
319 idents, msg = self.session.feed_identities(raw_msg, copy=False)
320 msg = self.session.unserialize(msg, content=False, copy=False)
320 msg = self.session.unserialize(msg, content=False, copy=False)
321 except Exception:
321 except Exception:
322 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
322 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
323 return
323 return
324
324
325
325
326 # send to monitor
326 # send to monitor
327 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
327 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
328
328
329 header = msg['header']
329 header = msg['header']
330 msg_id = header['msg_id']
330 msg_id = header['msg_id']
331 self.all_ids.add(msg_id)
331 self.all_ids.add(msg_id)
332
332
333 # get targets as a set of bytes objects
333 # get targets as a set of bytes objects
334 # from a list of unicode objects
334 # from a list of unicode objects
335 targets = header.get('targets', [])
335 targets = header.get('targets', [])
336 targets = map(asbytes, targets)
336 targets = map(asbytes, targets)
337 targets = set(targets)
337 targets = set(targets)
338
338
339 retries = header.get('retries', 0)
339 retries = header.get('retries', 0)
340 self.retries[msg_id] = retries
340 self.retries[msg_id] = retries
341
341
342 # time dependencies
342 # time dependencies
343 after = header.get('after', None)
343 after = header.get('after', None)
344 if after:
344 if after:
345 after = Dependency(after)
345 after = Dependency(after)
346 if after.all:
346 if after.all:
347 if after.success:
347 if after.success:
348 after = Dependency(after.difference(self.all_completed),
348 after = Dependency(after.difference(self.all_completed),
349 success=after.success,
349 success=after.success,
350 failure=after.failure,
350 failure=after.failure,
351 all=after.all,
351 all=after.all,
352 )
352 )
353 if after.failure:
353 if after.failure:
354 after = Dependency(after.difference(self.all_failed),
354 after = Dependency(after.difference(self.all_failed),
355 success=after.success,
355 success=after.success,
356 failure=after.failure,
356 failure=after.failure,
357 all=after.all,
357 all=after.all,
358 )
358 )
359 if after.check(self.all_completed, self.all_failed):
359 if after.check(self.all_completed, self.all_failed):
360 # recast as empty set, if `after` already met,
360 # recast as empty set, if `after` already met,
361 # to prevent unnecessary set comparisons
361 # to prevent unnecessary set comparisons
362 after = MET
362 after = MET
363 else:
363 else:
364 after = MET
364 after = MET
365
365
366 # location dependencies
366 # location dependencies
367 follow = Dependency(header.get('follow', []))
367 follow = Dependency(header.get('follow', []))
368
368
369 # turn timeouts into datetime objects:
369 # turn timeouts into datetime objects:
370 timeout = header.get('timeout', None)
370 timeout = header.get('timeout', None)
371 if timeout:
371 if timeout:
372 # cast to float, because jsonlib returns floats as decimal.Decimal,
372 # cast to float, because jsonlib returns floats as decimal.Decimal,
373 # which timedelta does not accept
373 # which timedelta does not accept
374 timeout = datetime.now() + timedelta(0,float(timeout),0)
374 timeout = datetime.now() + timedelta(0,float(timeout),0)
375
375
376 args = [raw_msg, targets, after, follow, timeout]
376 args = [raw_msg, targets, after, follow, timeout]
377
377
378 # validate and reduce dependencies:
378 # validate and reduce dependencies:
379 for dep in after,follow:
379 for dep in after,follow:
380 if not dep: # empty dependency
380 if not dep: # empty dependency
381 continue
381 continue
382 # check valid:
382 # check valid:
383 if msg_id in dep or dep.difference(self.all_ids):
383 if msg_id in dep or dep.difference(self.all_ids):
384 self.depending[msg_id] = args
384 self.depending[msg_id] = args
385 return self.fail_unreachable(msg_id, error.InvalidDependency)
385 return self.fail_unreachable(msg_id, error.InvalidDependency)
386 # check if unreachable:
386 # check if unreachable:
387 if dep.unreachable(self.all_completed, self.all_failed):
387 if dep.unreachable(self.all_completed, self.all_failed):
388 self.depending[msg_id] = args
388 self.depending[msg_id] = args
389 return self.fail_unreachable(msg_id)
389 return self.fail_unreachable(msg_id)
390
390
391 if after.check(self.all_completed, self.all_failed):
391 if after.check(self.all_completed, self.all_failed):
392 # time deps already met, try to run
392 # time deps already met, try to run
393 if not self.maybe_run(msg_id, *args):
393 if not self.maybe_run(msg_id, *args):
394 # can't run yet
394 # can't run yet
395 if msg_id not in self.all_failed:
395 if msg_id not in self.all_failed:
396 # could have failed as unreachable
396 # could have failed as unreachable
397 self.save_unmet(msg_id, *args)
397 self.save_unmet(msg_id, *args)
398 else:
398 else:
399 self.save_unmet(msg_id, *args)
399 self.save_unmet(msg_id, *args)
400
400
401 def audit_timeouts(self):
401 def audit_timeouts(self):
402 """Audit all waiting tasks for expired timeouts."""
402 """Audit all waiting tasks for expired timeouts."""
403 now = datetime.now()
403 now = datetime.now()
404 for msg_id in self.depending.keys():
404 for msg_id in self.depending.keys():
405 # must recheck, in case one failure cascaded to another:
405 # must recheck, in case one failure cascaded to another:
406 if msg_id in self.depending:
406 if msg_id in self.depending:
407 raw,after,targets,follow,timeout = self.depending[msg_id]
407 raw,after,targets,follow,timeout = self.depending[msg_id]
408 if timeout and timeout < now:
408 if timeout and timeout < now:
409 self.fail_unreachable(msg_id, error.TaskTimeout)
409 self.fail_unreachable(msg_id, error.TaskTimeout)
410
410
411 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
411 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
412 """a task has become unreachable, send a reply with an ImpossibleDependency
412 """a task has become unreachable, send a reply with an ImpossibleDependency
413 error."""
413 error."""
414 if msg_id not in self.depending:
414 if msg_id not in self.depending:
415 self.log.error("msg %r already failed!", msg_id)
415 self.log.error("msg %r already failed!", msg_id)
416 return
416 return
417 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
417 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
418 for mid in follow.union(after):
418 for mid in follow.union(after):
419 if mid in self.graph:
419 if mid in self.graph:
420 self.graph[mid].remove(msg_id)
420 self.graph[mid].remove(msg_id)
421
421
422 # FIXME: unpacking a message I've already unpacked, but didn't save:
422 # FIXME: unpacking a message I've already unpacked, but didn't save:
423 idents,msg = self.session.feed_identities(raw_msg, copy=False)
423 idents,msg = self.session.feed_identities(raw_msg, copy=False)
424 header = self.session.unpack(msg[1].bytes)
424 header = self.session.unpack(msg[1].bytes)
425
425
426 try:
426 try:
427 raise why()
427 raise why()
428 except:
428 except:
429 content = error.wrap_exception()
429 content = error.wrap_exception()
430
430
431 self.all_done.add(msg_id)
431 self.all_done.add(msg_id)
432 self.all_failed.add(msg_id)
432 self.all_failed.add(msg_id)
433
433
434 msg = self.session.send(self.client_stream, 'apply_reply', content,
434 msg = self.session.send(self.client_stream, 'apply_reply', content,
435 parent=header, ident=idents)
435 parent=header, ident=idents)
436 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
436 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
437
437
438 self.update_graph(msg_id, success=False)
438 self.update_graph(msg_id, success=False)
439
439
440 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
440 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
441 """check location dependencies, and run if they are met."""
441 """check location dependencies, and run if they are met."""
442 blacklist = self.blacklist.setdefault(msg_id, set())
442 blacklist = self.blacklist.setdefault(msg_id, set())
443 if follow or targets or blacklist or self.hwm:
443 if follow or targets or blacklist or self.hwm:
444 # we need a can_run filter
444 # we need a can_run filter
445 def can_run(idx):
445 def can_run(idx):
446 # check hwm
446 # check hwm
447 if self.hwm and self.loads[idx] == self.hwm:
447 if self.hwm and self.loads[idx] == self.hwm:
448 return False
448 return False
449 target = self.targets[idx]
449 target = self.targets[idx]
450 # check blacklist
450 # check blacklist
451 if target in blacklist:
451 if target in blacklist:
452 return False
452 return False
453 # check targets
453 # check targets
454 if targets and target not in targets:
454 if targets and target not in targets:
455 return False
455 return False
456 # check follow
456 # check follow
457 return follow.check(self.completed[target], self.failed[target])
457 return follow.check(self.completed[target], self.failed[target])
458
458
459 indices = filter(can_run, range(len(self.targets)))
459 indices = filter(can_run, range(len(self.targets)))
460
460
461 if not indices:
461 if not indices:
462 # couldn't run
462 # couldn't run
463 if follow.all:
463 if follow.all:
464 # check follow for impossibility
464 # check follow for impossibility
465 dests = set()
465 dests = set()
466 relevant = set()
466 relevant = set()
467 if follow.success:
467 if follow.success:
468 relevant = self.all_completed
468 relevant = self.all_completed
469 if follow.failure:
469 if follow.failure:
470 relevant = relevant.union(self.all_failed)
470 relevant = relevant.union(self.all_failed)
471 for m in follow.intersection(relevant):
471 for m in follow.intersection(relevant):
472 dests.add(self.destinations[m])
472 dests.add(self.destinations[m])
473 if len(dests) > 1:
473 if len(dests) > 1:
474 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
474 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
475 self.fail_unreachable(msg_id)
475 self.fail_unreachable(msg_id)
476 return False
476 return False
477 if targets:
477 if targets:
478 # check blacklist+targets for impossibility
478 # check blacklist+targets for impossibility
479 targets.difference_update(blacklist)
479 targets.difference_update(blacklist)
480 if not targets or not targets.intersection(self.targets):
480 if not targets or not targets.intersection(self.targets):
481 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
481 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
482 self.fail_unreachable(msg_id)
482 self.fail_unreachable(msg_id)
483 return False
483 return False
484 return False
484 return False
485 else:
485 else:
486 indices = None
486 indices = None
487
487
488 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
488 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
489 return True
489 return True
490
490
491 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
491 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
492 """Save a message for later submission when its dependencies are met."""
492 """Save a message for later submission when its dependencies are met."""
493 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
493 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
494 # track the ids in follow or after, but not those already finished
494 # track the ids in follow or after, but not those already finished
495 for dep_id in after.union(follow).difference(self.all_done):
495 for dep_id in after.union(follow).difference(self.all_done):
496 if dep_id not in self.graph:
496 if dep_id not in self.graph:
497 self.graph[dep_id] = set()
497 self.graph[dep_id] = set()
498 self.graph[dep_id].add(msg_id)
498 self.graph[dep_id].add(msg_id)
499
499
500 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
500 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
501 """Submit a task to any of a subset of our targets."""
501 """Submit a task to any of a subset of our targets."""
502 if indices:
502 if indices:
503 loads = [self.loads[i] for i in indices]
503 loads = [self.loads[i] for i in indices]
504 else:
504 else:
505 loads = self.loads
505 loads = self.loads
506 idx = self.scheme(loads)
506 idx = self.scheme(loads)
507 if indices:
507 if indices:
508 idx = indices[idx]
508 idx = indices[idx]
509 target = self.targets[idx]
509 target = self.targets[idx]
510 # print (target, map(str, msg[:3]))
510 # print (target, map(str, msg[:3]))
511 # send job to the engine
511 # send job to the engine
512 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
512 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
513 self.engine_stream.send_multipart(raw_msg, copy=False)
513 self.engine_stream.send_multipart(raw_msg, copy=False)
514 # update load
514 # update load
515 self.add_job(idx)
515 self.add_job(idx)
516 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
516 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
517 # notify Hub
517 # notify Hub
518 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
518 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
519 self.session.send(self.mon_stream, 'task_destination', content=content,
519 self.session.send(self.mon_stream, 'task_destination', content=content,
520 ident=[b'tracktask',self.ident])
520 ident=[b'tracktask',self.ident])
521
521
522
522
523 #-----------------------------------------------------------------------
523 #-----------------------------------------------------------------------
524 # Result Handling
524 # Result Handling
525 #-----------------------------------------------------------------------
525 #-----------------------------------------------------------------------
526 def dispatch_result(self, raw_msg):
526 def dispatch_result(self, raw_msg):
527 """dispatch method for result replies"""
527 """dispatch method for result replies"""
528 try:
528 try:
529 idents,msg = self.session.feed_identities(raw_msg, copy=False)
529 idents,msg = self.session.feed_identities(raw_msg, copy=False)
530 msg = self.session.unserialize(msg, content=False, copy=False)
530 msg = self.session.unserialize(msg, content=False, copy=False)
531 engine = idents[0]
531 engine = idents[0]
532 try:
532 try:
533 idx = self.targets.index(engine)
533 idx = self.targets.index(engine)
534 except ValueError:
534 except ValueError:
535 pass # skip load-update for dead engines
535 pass # skip load-update for dead engines
536 else:
536 else:
537 self.finish_job(idx)
537 self.finish_job(idx)
538 except Exception:
538 except Exception:
539 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
539 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
540 return
540 return
541
541
542 header = msg['header']
542 header = msg['header']
543 parent = msg['parent_header']
543 parent = msg['parent_header']
544 if header.get('dependencies_met', True):
544 if header.get('dependencies_met', True):
545 success = (header['status'] == 'ok')
545 success = (header['status'] == 'ok')
546 msg_id = parent['msg_id']
546 msg_id = parent['msg_id']
547 retries = self.retries[msg_id]
547 retries = self.retries[msg_id]
548 if not success and retries > 0:
548 if not success and retries > 0:
549 # failed
549 # failed
550 self.retries[msg_id] = retries - 1
550 self.retries[msg_id] = retries - 1
551 self.handle_unmet_dependency(idents, parent)
551 self.handle_unmet_dependency(idents, parent)
552 else:
552 else:
553 del self.retries[msg_id]
553 del self.retries[msg_id]
554 # relay to client and update graph
554 # relay to client and update graph
555 self.handle_result(idents, parent, raw_msg, success)
555 self.handle_result(idents, parent, raw_msg, success)
556 # send to Hub monitor
556 # send to Hub monitor
557 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
557 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
558 else:
558 else:
559 self.handle_unmet_dependency(idents, parent)
559 self.handle_unmet_dependency(idents, parent)
560
560
561 def handle_result(self, idents, parent, raw_msg, success=True):
561 def handle_result(self, idents, parent, raw_msg, success=True):
562 """handle a real task result, either success or failure"""
562 """handle a real task result, either success or failure"""
563 # first, relay result to client
563 # first, relay result to client
564 engine = idents[0]
564 engine = idents[0]
565 client = idents[1]
565 client = idents[1]
566 # swap_ids for XREP-XREP mirror
566 # swap_ids for XREP-XREP mirror
567 raw_msg[:2] = [client,engine]
567 raw_msg[:2] = [client,engine]
568 # print (map(str, raw_msg[:4]))
568 # print (map(str, raw_msg[:4]))
569 self.client_stream.send_multipart(raw_msg, copy=False)
569 self.client_stream.send_multipart(raw_msg, copy=False)
570 # now, update our data structures
570 # now, update our data structures
571 msg_id = parent['msg_id']
571 msg_id = parent['msg_id']
572 self.blacklist.pop(msg_id, None)
572 self.blacklist.pop(msg_id, None)
573 self.pending[engine].pop(msg_id)
573 self.pending[engine].pop(msg_id)
574 if success:
574 if success:
575 self.completed[engine].add(msg_id)
575 self.completed[engine].add(msg_id)
576 self.all_completed.add(msg_id)
576 self.all_completed.add(msg_id)
577 else:
577 else:
578 self.failed[engine].add(msg_id)
578 self.failed[engine].add(msg_id)
579 self.all_failed.add(msg_id)
579 self.all_failed.add(msg_id)
580 self.all_done.add(msg_id)
580 self.all_done.add(msg_id)
581 self.destinations[msg_id] = engine
581 self.destinations[msg_id] = engine
582
582
583 self.update_graph(msg_id, success)
583 self.update_graph(msg_id, success)
584
584
585 def handle_unmet_dependency(self, idents, parent):
585 def handle_unmet_dependency(self, idents, parent):
586 """handle an unmet dependency"""
586 """handle an unmet dependency"""
587 engine = idents[0]
587 engine = idents[0]
588 msg_id = parent['msg_id']
588 msg_id = parent['msg_id']
589
589
590 if msg_id not in self.blacklist:
590 if msg_id not in self.blacklist:
591 self.blacklist[msg_id] = set()
591 self.blacklist[msg_id] = set()
592 self.blacklist[msg_id].add(engine)
592 self.blacklist[msg_id].add(engine)
593
593
594 args = self.pending[engine].pop(msg_id)
594 args = self.pending[engine].pop(msg_id)
595 raw,targets,after,follow,timeout = args
595 raw,targets,after,follow,timeout = args
596
596
597 if self.blacklist[msg_id] == targets:
597 if self.blacklist[msg_id] == targets:
598 self.depending[msg_id] = args
598 self.depending[msg_id] = args
599 self.fail_unreachable(msg_id)
599 self.fail_unreachable(msg_id)
600 elif not self.maybe_run(msg_id, *args):
600 elif not self.maybe_run(msg_id, *args):
601 # resubmit failed
601 # resubmit failed
602 if msg_id not in self.all_failed:
602 if msg_id not in self.all_failed:
603 # put it back in our dependency tree
603 # put it back in our dependency tree
604 self.save_unmet(msg_id, *args)
604 self.save_unmet(msg_id, *args)
605
605
606 if self.hwm:
606 if self.hwm:
607 try:
607 try:
608 idx = self.targets.index(engine)
608 idx = self.targets.index(engine)
609 except ValueError:
609 except ValueError:
610 pass # skip load-update for dead engines
610 pass # skip load-update for dead engines
611 else:
611 else:
612 if self.loads[idx] == self.hwm-1:
612 if self.loads[idx] == self.hwm-1:
613 self.update_graph(None)
613 self.update_graph(None)
614
614
615
615
616
616
617 def update_graph(self, dep_id=None, success=True):
617 def update_graph(self, dep_id=None, success=True):
618 """dep_id just finished. Update our dependency
618 """dep_id just finished. Update our dependency
619 graph and submit any jobs that just became runable.
619 graph and submit any jobs that just became runable.
620
620
621 Called with dep_id=None to update entire graph for hwm, but without finishing
621 Called with dep_id=None to update entire graph for hwm, but without finishing
622 a task.
622 a task.
623 """
623 """
624 # print ("\n\n***********")
624 # print ("\n\n***********")
625 # pprint (dep_id)
625 # pprint (dep_id)
626 # pprint (self.graph)
626 # pprint (self.graph)
627 # pprint (self.depending)
627 # pprint (self.depending)
628 # pprint (self.all_completed)
628 # pprint (self.all_completed)
629 # pprint (self.all_failed)
629 # pprint (self.all_failed)
630 # print ("\n\n***********\n\n")
630 # print ("\n\n***********\n\n")
631 # update any jobs that depended on the dependency
631 # update any jobs that depended on the dependency
632 jobs = self.graph.pop(dep_id, [])
632 jobs = self.graph.pop(dep_id, [])
633
633
634 # recheck *all* jobs if
634 # recheck *all* jobs if
635 # a) we have HWM and an engine just become no longer full
635 # a) we have HWM and an engine just become no longer full
636 # or b) dep_id was given as None
636 # or b) dep_id was given as None
637 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
637 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
638 jobs = self.depending.keys()
638 jobs = self.depending.keys()
639
639
640 for msg_id in jobs:
640 for msg_id in jobs:
641 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
641 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
642
642
643 if after.unreachable(self.all_completed, self.all_failed)\
643 if after.unreachable(self.all_completed, self.all_failed)\
644 or follow.unreachable(self.all_completed, self.all_failed):
644 or follow.unreachable(self.all_completed, self.all_failed):
645 self.fail_unreachable(msg_id)
645 self.fail_unreachable(msg_id)
646
646
647 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
647 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
648 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
648 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
649
649
650 self.depending.pop(msg_id)
650 self.depending.pop(msg_id)
651 for mid in follow.union(after):
651 for mid in follow.union(after):
652 if mid in self.graph:
652 if mid in self.graph:
653 self.graph[mid].remove(msg_id)
653 self.graph[mid].remove(msg_id)
654
654
655 #----------------------------------------------------------------------
655 #----------------------------------------------------------------------
656 # methods to be overridden by subclasses
656 # methods to be overridden by subclasses
657 #----------------------------------------------------------------------
657 #----------------------------------------------------------------------
658
658
659 def add_job(self, idx):
659 def add_job(self, idx):
660 """Called after self.targets[idx] just got the job with header.
660 """Called after self.targets[idx] just got the job with header.
661 Override with subclasses. The default ordering is simple LRU.
661 Override with subclasses. The default ordering is simple LRU.
662 The default loads are the number of outstanding jobs."""
662 The default loads are the number of outstanding jobs."""
663 self.loads[idx] += 1
663 self.loads[idx] += 1
664 for lis in (self.targets, self.loads):
664 for lis in (self.targets, self.loads):
665 lis.append(lis.pop(idx))
665 lis.append(lis.pop(idx))
666
666
667
667
668 def finish_job(self, idx):
668 def finish_job(self, idx):
669 """Called after self.targets[idx] just finished a job.
669 """Called after self.targets[idx] just finished a job.
670 Override with subclasses."""
670 Override with subclasses."""
671 self.loads[idx] -= 1
671 self.loads[idx] -= 1
672
672
673
673
674
674
675 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
675 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
676 logname='root', log_url=None, loglevel=logging.DEBUG,
676 logname='root', log_url=None, loglevel=logging.DEBUG,
677 identity=b'task', in_thread=False):
677 identity=b'task', in_thread=False):
678
678
679 ZMQStream = zmqstream.ZMQStream
679 ZMQStream = zmqstream.ZMQStream
680
680
681 if config:
681 if config:
682 # unwrap dict back into Config
682 # unwrap dict back into Config
683 config = Config(config)
683 config = Config(config)
684
684
685 if in_thread:
685 if in_thread:
686 # use instance() to get the same Context/Loop as our parent
686 # use instance() to get the same Context/Loop as our parent
687 ctx = zmq.Context.instance()
687 ctx = zmq.Context.instance()
688 loop = ioloop.IOLoop.instance()
688 loop = ioloop.IOLoop.instance()
689 else:
689 else:
690 # in a process, don't use instance()
690 # in a process, don't use instance()
691 # for safety with multiprocessing
691 # for safety with multiprocessing
692 ctx = zmq.Context()
692 ctx = zmq.Context()
693 loop = ioloop.IOLoop()
693 loop = ioloop.IOLoop()
694 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
694 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
695 ins.setsockopt(zmq.IDENTITY, identity)
695 ins.setsockopt(zmq.IDENTITY, identity)
696 ins.bind(in_addr)
696 ins.bind(in_addr)
697
697
698 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
698 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
699 outs.setsockopt(zmq.IDENTITY, identity)
699 outs.setsockopt(zmq.IDENTITY, identity)
700 outs.bind(out_addr)
700 outs.bind(out_addr)
701 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
701 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
702 mons.connect(mon_addr)
702 mons.connect(mon_addr)
703 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
703 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
704 nots.setsockopt(zmq.SUBSCRIBE, b'')
704 nots.setsockopt(zmq.SUBSCRIBE, b'')
705 nots.connect(not_addr)
705 nots.connect(not_addr)
706
706
707 # setup logging.
707 # setup logging.
708 if in_thread:
708 if in_thread:
709 log = Application.instance().log
709 log = Application.instance().log
710 else:
710 else:
711 if log_url:
711 if log_url:
712 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
712 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
713 else:
713 else:
714 log = local_logger(logname, loglevel)
714 log = local_logger(logname, loglevel)
715
715
716 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
716 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
717 mon_stream=mons, notifier_stream=nots,
717 mon_stream=mons, notifier_stream=nots,
718 loop=loop, log=log,
718 loop=loop, log=log,
719 config=config)
719 config=config)
720 scheduler.start()
720 scheduler.start()
721 if not in_thread:
721 if not in_thread:
722 try:
722 try:
723 loop.start()
723 loop.start()
724 except KeyboardInterrupt:
724 except KeyboardInterrupt:
725 print ("interrupted, exiting...", file=sys.__stderr__)
725 scheduler.log.critical("Interrupted, exiting...")
726
726
General Comments 0
You need to be logged in to leave comments. Login now