add HWM to TaskScheduler...
MinRK -
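For context on the diffs below: the outstanding-task limit moves from the pure-ZMQ scheduler option (c.ControllerFactory.hwm) to the Python TaskScheduler (c.TaskScheduler.hwm). A minimal sketch of how the new option might be set in a controller config file, following the structure of the config shown in the first diff; the value 1 is only illustrative, and 0 (the new default) keeps the behavior of distributing tasks without waiting:

from IPython.config.loader import Config

c = get_config()

# Allow at most one outstanding task per engine at a time (hypothetical choice);
# 0, the default, never waits for tasks to finish before assigning more.
c.TaskScheduler.hwm = 1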
@@ -1,180 +1,180 @@
1 1 from IPython.config.loader import Config
2 2
3 3 c = get_config()
4 4
5 5 #-----------------------------------------------------------------------------
6 6 # Global configuration
7 7 #-----------------------------------------------------------------------------
8 8
9 9 # Basic Global config attributes
10 10
11 11 # Start up messages are logged to stdout using the logging module.
12 12 # These all happen before the twisted reactor is started and are
13 13 # useful for debugging purposes. Can be (10=DEBUG, 20=INFO, 30=WARN, 40=ERROR, 50=CRITICAL)
14 14 # and smaller is more verbose.
15 15 # c.Global.log_level = 20
16 16
17 17 # Log to a file in cluster_dir/log, otherwise just log to sys.stdout.
18 18 # c.Global.log_to_file = False
19 19
20 20 # Remove old logs from cluster_dir/log before starting.
21 21 # c.Global.clean_logs = True
22 22
23 23 # A list of Python statements that will be run before starting the
24 24 # controller. This is provided because occasionally certain things need to
25 25 # be imported in the controller for pickling to work.
26 26 # c.Global.import_statements = ['import math']
27 27
28 28 # Reuse the controller's JSON files. If False, JSON files are regenerated
29 29 # each time the controller is run. If True, they will be reused, *but*, you
30 30 # also must set the network ports by hand. If set, this will override the
31 31 # values set for the client and engine connections below.
32 32 # c.Global.reuse_files = True
33 33
34 34 # Enable exec_key authentication on all messages. Default is True
35 35 # c.Global.secure = True
36 36
37 37 # The working directory for the process. The application will use os.chdir
38 38 # to change to this directory before starting.
39 39 # c.Global.work_dir = os.getcwd()
40 40
41 41 # The log url for logging to an `iploggerz` application. This will override
42 42 # log-to-file.
43 43 # c.Global.log_url = 'tcp://127.0.0.1:20202'
44 44
45 45 # The specific external IP that is used to disambiguate multi-interface URLs.
46 46 # The default behavior is to guess from external IPs gleaned from `socket`.
47 47 # c.Global.location = '192.168.1.123'
48 48
49 49 # The ssh server remote clients should use to connect to this controller.
50 50 # It must be a machine that can see the interface specified in client_ip.
51 51 # The default for client_ip is localhost, in which case the sshserver must
52 52 # be an external IP of the controller machine.
53 53 # c.Global.sshserver = 'controller.example.com'
54 54
55 55 # the url to use for registration. If set, this overrides engine-ip,
56 56 # engine-transport, client-ip, client-transport, and regport.
57 57 # c.RegistrationFactory.url = 'tcp://*:12345'
58 58
59 59 # the port to use for registration. Clients and Engines both use this
60 60 # port for registration.
61 61 # c.RegistrationFactory.regport = 10101
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Configure the Task Scheduler
65 65 #-----------------------------------------------------------------------------
66 66
67 67 # The routing scheme. 'pure' will use the pure-ZMQ scheduler. Any other
68 68 # value will use a Python scheduler with various routing schemes.
69 69 # python schemes are: lru, weighted, random, twobin. Default is 'weighted'.
70 70 # Note that the pure ZMQ scheduler does not support many features, such as
71 71 # dying engines, dependencies, or engine-subset load-balancing.
72 72 # c.ControllerFactory.scheme = 'pure'
73 73
74 # The pure ZMQ scheduler can limit the number of outstanding tasks per engine
75 # by using the ZMQ HWM option. This allows engines with long-running tasks
74 # The Python scheduler can limit the number of outstanding tasks per engine
75 # by using an HWM option. This allows engines with long-running tasks
76 76 # to not steal too many tasks from other engines. The default is 0, which
77 77 # means aggressively distribute messages, never waiting for them to finish.
78 # c.ControllerFactory.hwm = 1
78 # c.TaskScheduler.hwm = 0
79 79
80 80 # Whether to use Threads or Processes to start the Schedulers. Threads will
81 81 # use less resources, but potentially reduce throughput. Default is to
82 82 # use processes. Note that a Python scheduler will always be in a Process.
83 83 # c.ControllerFactory.usethreads
84 84
85 85 #-----------------------------------------------------------------------------
86 86 # Configure the Hub
87 87 #-----------------------------------------------------------------------------
88 88
89 89 # Which class to use for the db backend. Currently supported are DictDB (the
90 90 # default), and MongoDB. Uncomment this line to enable MongoDB, which will
91 91 # slow down the Hub's responsiveness, but also reduce its memory footprint.
92 92 # c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'
93 93
94 94 # The heartbeat ping frequency. This is the frequency (in ms) at which the
95 95 # Hub pings engines for heartbeats. This determines how quickly the Hub
96 96 # will react to engines coming and going. A lower number means faster response
97 97 # time, but more network activity. The default is 100ms
98 98 # c.HubFactory.ping = 100
99 99
100 100 # HubFactory queue port pairs, to set by name: mux, iopub, control, task. Set
101 101 # each as a tuple of length 2 of ints. The default is to find random
102 102 # available ports
103 103 # c.HubFactory.mux = (10102,10112)
104 104
105 105 #-----------------------------------------------------------------------------
106 106 # Configure the client connections
107 107 #-----------------------------------------------------------------------------
108 108
109 109 # Basic client connection config attributes
110 110
111 111 # The network interface the controller will listen on for client connections.
112 112 # This should be an IP address or interface on the controller. An asterisk
113 113 # means listen on all interfaces. The transport can be any transport
114 114 # supported by zeromq (tcp,epgm,pgm,ib,ipc):
115 115 # c.HubFactory.client_ip = '*'
116 116 # c.HubFactory.client_transport = 'tcp'
117 117
118 118 # individual client ports to configure by name: query_port, notifier_port
119 119 # c.HubFactory.query_port = 12345
120 120
121 121 #-----------------------------------------------------------------------------
122 122 # Configure the engine connections
123 123 #-----------------------------------------------------------------------------
124 124
125 125 # Basic config attributes for the engine connections.
126 126
127 127 # The network interface the controller will listen on for engine connections.
128 128 # This should be an IP address or interface on the controller. An asterisk
129 129 # means listen on all interfaces. The transport can be any transport
130 130 # supported by zeromq (tcp,epgm,pgm,ib,ipc):
131 131 # c.HubFactory.engine_ip = '*'
132 132 # c.HubFactory.engine_transport = 'tcp'
133 133
134 134 # set the engine heartbeat ports to use:
135 135 # c.HubFactory.hb = (10303,10313)
136 136
137 137 #-----------------------------------------------------------------------------
138 138 # Configure the TaskRecord database backend
139 139 #-----------------------------------------------------------------------------
140 140
141 141 # For memory/persistence reasons, tasks can be stored out-of-memory in a database.
142 142 # Currently, only sqlite and mongodb are supported as backends, but the interface
143 143 # is fairly simple, so advanced developers could write their own backend.
144 144
145 145 # ----- in-memory configuration --------
146 146 # this line restores the default behavior: in-memory storage of all results.
147 147 # c.HubFactory.db_class = 'IPython.parallel.controller.dictdb.DictDB'
148 148
149 149 # ----- sqlite configuration --------
150 150 # use this line to activate sqlite:
151 151 # c.HubFactory.db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'
152 152
153 153 # You can specify the name of the db-file. By default, this will be located
154 154 # in the active cluster_dir, e.g. ~/.ipython/clusterz_default/tasks.db
155 155 # c.SQLiteDB.filename = 'tasks.db'
156 156
157 157 # You can also specify the location of the db-file, if you want it to be somewhere
158 158 # other than the cluster_dir.
159 159 # c.SQLiteDB.location = '/scratch/'
160 160
161 161 # This will specify the name of the table for the controller to use. The default
162 162 # behavior is to use the session ID of the SessionFactory object (a uuid). Overriding
163 163 # this will result in results persisting for multiple sessions.
164 164 # c.SQLiteDB.table = 'results'
165 165
166 166 # ----- mongodb configuration --------
167 167 # use this line to activate mongodb:
168 168 # c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'
169 169
170 170 # You can specify the args and kwargs pymongo will use when creating the Connection.
171 171 # For more information on what these options might be, see pymongo documentation.
172 172 # c.MongoDB.connection_kwargs = {}
173 173 # c.MongoDB.connection_args = []
174 174
175 175 # This will specify the name of the mongo database for the controller to use. The default
176 176 # behavior is to use the session ID of the SessionFactory object (a uuid). Overriding
177 177 # this will result in task results persisting through multiple sessions.
178 178 # c.MongoDB.database = 'ipythondb'
179 179
180 180
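Before the scheduler diff itself, a standalone sketch of the per-engine rule the new hwm attribute enforces in maybe_run below: an engine is skipped once its current load reaches the HWM, and a value of 0 disables the limit. The function and variable names here are illustrative, not part of the scheduler's API:

def engine_accepts_work(load, hwm):
    # 0 means "no limit"; otherwise the engine is full once load reaches hwm
    return hwm == 0 or load < hwm

loads = [0, 1, 2]   # outstanding tasks on three engines
print([engine_accepts_work(l, 2) for l in loads])   # [True, True, False]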
@@ -1,596 +1,621 @@
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #----------------------------------------------------------------------
15 15 # Imports
16 16 #----------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import logging
21 21 import sys
22 22
23 23 from datetime import datetime, timedelta
24 24 from random import randint, random
25 25 from types import FunctionType
26 26
27 27 try:
28 28 import numpy
29 29 except ImportError:
30 30 numpy = None
31 31
32 32 import zmq
33 33 from zmq.eventloop import ioloop, zmqstream
34 34
35 35 # local imports
36 36 from IPython.external.decorator import decorator
37 37 from IPython.config.loader import Config
38 from IPython.utils.traitlets import Instance, Dict, List, Set
38 from IPython.utils.traitlets import Instance, Dict, List, Set, Int
39 39
40 40 from IPython.parallel import error
41 41 from IPython.parallel.factory import SessionFactory
42 42 from IPython.parallel.util import connect_logger, local_logger
43 43
44 44 from .dependency import Dependency
45 45
46 46 @decorator
47 47 def logged(f,self,*args,**kwargs):
48 48 # print ("#--------------------")
49 49 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
50 50 # print ("#--")
51 51 return f(self,*args, **kwargs)
52 52
53 53 #----------------------------------------------------------------------
54 54 # Chooser functions
55 55 #----------------------------------------------------------------------
56 56
57 57 def plainrandom(loads):
58 58 """Plain random pick."""
59 59 n = len(loads)
60 60 return randint(0,n-1)
61 61
62 62 def lru(loads):
63 63 """Always pick the front of the line.
64 64
65 65 The content of `loads` is ignored.
66 66
67 67 Assumes LRU ordering of loads, with oldest first.
68 68 """
69 69 return 0
70 70
71 71 def twobin(loads):
72 72 """Pick two at random, use the LRU of the two.
73 73
74 74 The content of loads is ignored.
75 75
76 76 Assumes LRU ordering of loads, with oldest first.
77 77 """
78 78 n = len(loads)
79 79 a = randint(0,n-1)
80 80 b = randint(0,n-1)
81 81 return min(a,b)
82 82
83 83 def weighted(loads):
84 84 """Pick two at random using inverse load as weight.
85 85
86 86 Return the less loaded of the two.
87 87 """
88 88 # weight 0 a million times more than 1:
89 89 weights = 1./(1e-6+numpy.array(loads))
90 90 sums = weights.cumsum()
91 91 t = sums[-1]
92 92 x = random()*t
93 93 y = random()*t
94 94 idx = 0
95 95 idy = 0
96 96 while sums[idx] < x:
97 97 idx += 1
98 98 while sums[idy] < y:
99 99 idy += 1
100 100 if weights[idy] > weights[idx]:
101 101 return idy
102 102 else:
103 103 return idx
104 104
105 105 def leastload(loads):
106 106 """Always choose the lowest load.
107 107
108 108 If the lowest load occurs more than once, the first
109 109 occurrence will be used. If loads has LRU ordering, this means
110 110 the LRU of those with the lowest load is chosen.
111 111 """
112 112 return loads.index(min(loads))
113 113
114 114 #---------------------------------------------------------------------
115 115 # Classes
116 116 #---------------------------------------------------------------------
117 117 # store empty default dependency:
118 118 MET = Dependency([])
119 119
120 120 class TaskScheduler(SessionFactory):
121 121 """Python TaskScheduler object.
122 122
123 123 This is the simplest object that supports msg_id based
124 124 DAG dependencies. *Only* task msg_ids are checked, not
125 125 msg_ids of jobs submitted via the MUX queue.
126 126
127 127 """
128 128
129 hwm = Int(0, config=True) # limit number of outstanding tasks
130
129 131 # input arguments:
130 132 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
131 133 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
132 134 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
133 135 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
134 136 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
135 137
136 138 # internals:
137 139 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
140 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
138 141 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
139 142 pending = Dict() # dict by engine_uuid of submitted tasks
140 143 completed = Dict() # dict by engine_uuid of completed tasks
141 144 failed = Dict() # dict by engine_uuid of failed tasks
142 145 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
143 146 clients = Dict() # dict by msg_id for who submitted the task
144 147 targets = List() # list of target IDENTs
145 148 loads = List() # list of engine loads
149 # full = Set() # set of IDENTs that have HWM outstanding tasks
146 150 all_completed = Set() # set of all completed tasks
147 151 all_failed = Set() # set of all failed tasks
148 152 all_done = Set() # set of all finished tasks=union(completed,failed)
149 153 all_ids = Set() # set of all submitted task IDs
150 154 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
151 155 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
152 156
153 157
154 158 def start(self):
155 159 self.engine_stream.on_recv(self.dispatch_result, copy=False)
156 160 self._notification_handlers = dict(
157 161 registration_notification = self._register_engine,
158 162 unregistration_notification = self._unregister_engine
159 163 )
160 164 self.notifier_stream.on_recv(self.dispatch_notification)
161 165 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2 seconds
162 166 self.auditor.start()
163 167 self.log.info("Scheduler started...%r"%self)
164 168
165 169 def resume_receiving(self):
166 170 """Resume accepting jobs."""
167 171 self.client_stream.on_recv(self.dispatch_submission, copy=False)
168 172
169 173 def stop_receiving(self):
170 174 """Stop accepting jobs while there are no engines.
171 175 Leave them in the ZMQ queue."""
172 176 self.client_stream.on_recv(None)
173 177
174 178 #-----------------------------------------------------------------------
175 179 # [Un]Registration Handling
176 180 #-----------------------------------------------------------------------
177 181
178 182 def dispatch_notification(self, msg):
179 183 """dispatch register/unregister events."""
180 184 idents,msg = self.session.feed_identities(msg)
181 185 msg = self.session.unpack_message(msg)
182 186 msg_type = msg['msg_type']
183 187 handler = self._notification_handlers.get(msg_type, None)
184 188 if handler is None:
185 189 raise Exception("Unhandled message type: %s"%msg_type)
186 190 else:
187 191 try:
188 192 handler(str(msg['content']['queue']))
189 193 except KeyError:
190 194 self.log.error("task::Invalid notification msg: %s"%msg)
191 195
192 196 @logged
193 197 def _register_engine(self, uid):
194 198 """New engine with ident `uid` became available."""
195 199 # head of the line:
196 200 self.targets.insert(0,uid)
197 201 self.loads.insert(0,0)
198 202 # initialize sets
199 203 self.completed[uid] = set()
200 204 self.failed[uid] = set()
201 205 self.pending[uid] = {}
202 206 if len(self.targets) == 1:
203 207 self.resume_receiving()
204 208
205 209 def _unregister_engine(self, uid):
206 210 """Existing engine with ident `uid` became unavailable."""
207 211 if len(self.targets) == 1:
208 212 # this was our only engine
209 213 self.stop_receiving()
210 214
211 215 # handle any potentially finished tasks:
212 216 self.engine_stream.flush()
213 217
214 218 self.completed.pop(uid)
215 219 self.failed.pop(uid)
216 220 # don't pop destinations, because it might be used later
217 221 # map(self.destinations.pop, self.completed.pop(uid))
218 222 # map(self.destinations.pop, self.failed.pop(uid))
219
220 223 idx = self.targets.index(uid)
221 224 self.targets.pop(idx)
222 225 self.loads.pop(idx)
223 226
224 227 # wait 5 seconds before cleaning up pending jobs, since the results might
225 228 # still be incoming
226 229 if self.pending[uid]:
227 230 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
228 231 dc.start()
229 232
230 233 @logged
231 234 def handle_stranded_tasks(self, engine):
232 235 """Deal with jobs resident in an engine that died."""
233 236 lost = self.pending.pop(engine)
234 237
235 238 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
236 239 self.all_failed.add(msg_id)
237 240 self.all_done.add(msg_id)
238 241 idents,msg = self.session.feed_identities(raw_msg, copy=False)
239 242 msg = self.session.unpack_message(msg, copy=False, content=False)
240 243 parent = msg['header']
241 244 idents = [idents[0],engine]+idents[1:]
242 245 # print (idents)
243 246 try:
244 247 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
245 248 except:
246 249 content = error.wrap_exception()
247 250 msg = self.session.send(self.client_stream, 'apply_reply', content,
248 251 parent=parent, ident=idents)
249 252 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
250 253 self.update_graph(msg_id)
251 254
252 255
253 256 #-----------------------------------------------------------------------
254 257 # Job Submission
255 258 #-----------------------------------------------------------------------
256 259 @logged
257 260 def dispatch_submission(self, raw_msg):
258 261 """Dispatch job submission to appropriate handlers."""
259 262 # ensure targets up to date:
260 263 self.notifier_stream.flush()
261 264 try:
262 265 idents, msg = self.session.feed_identities(raw_msg, copy=False)
263 266 msg = self.session.unpack_message(msg, content=False, copy=False)
264 except:
267 except Exception:
265 268 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
266 269 return
267 270
268 271 # send to monitor
269 272 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
270 273
271 274 header = msg['header']
272 275 msg_id = header['msg_id']
273 276 self.all_ids.add(msg_id)
274 277
275 278 # targets
276 279 targets = set(header.get('targets', []))
277 280
278 281 # time dependencies
279 282 after = Dependency(header.get('after', []))
280 283 if after.all:
281 284 if after.success:
282 285 after.difference_update(self.all_completed)
283 286 if after.failure:
284 287 after.difference_update(self.all_failed)
285 288 if after.check(self.all_completed, self.all_failed):
286 289 # recast as empty set, if `after` already met,
287 290 # to prevent unnecessary set comparisons
288 291 after = MET
289 292
290 293 # location dependencies
291 294 follow = Dependency(header.get('follow', []))
292 295
293 296 # turn timeouts into datetime objects:
294 297 timeout = header.get('timeout', None)
295 298 if timeout:
296 299 timeout = datetime.now() + timedelta(0,timeout,0)
297 300
298 301 args = [raw_msg, targets, after, follow, timeout]
299 302
300 303 # validate and reduce dependencies:
301 304 for dep in after,follow:
302 305 # check valid:
303 306 if msg_id in dep or dep.difference(self.all_ids):
304 307 self.depending[msg_id] = args
305 308 return self.fail_unreachable(msg_id, error.InvalidDependency)
306 309 # check if unreachable:
307 310 if dep.unreachable(self.all_completed, self.all_failed):
308 311 self.depending[msg_id] = args
309 312 return self.fail_unreachable(msg_id)
310 313
311 314 if after.check(self.all_completed, self.all_failed):
312 315 # time deps already met, try to run
313 316 if not self.maybe_run(msg_id, *args):
314 317 # can't run yet
315 318 self.save_unmet(msg_id, *args)
316 319 else:
317 320 self.save_unmet(msg_id, *args)
318 321
319 322 # @logged
320 323 def audit_timeouts(self):
321 324 """Audit all waiting tasks for expired timeouts."""
322 325 now = datetime.now()
323 326 for msg_id in self.depending.keys():
324 327 # must recheck, in case one failure cascaded to another:
325 328 if msg_id in self.depending:
326 329 raw,targets,after,follow,timeout = self.depending[msg_id]
327 330 if timeout and timeout < now:
328 331 self.fail_unreachable(msg_id, timeout=True)
329 332
330 333 @logged
331 334 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
332 335 """a task has become unreachable, send a reply with an ImpossibleDependency
333 336 error."""
334 337 if msg_id not in self.depending:
335 338 self.log.error("msg %r already failed!"%msg_id)
336 339 return
337 340 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
338 341 for mid in follow.union(after):
339 342 if mid in self.graph:
340 343 self.graph[mid].remove(msg_id)
341 344
342 345 # FIXME: unpacking a message I've already unpacked, but didn't save:
343 346 idents,msg = self.session.feed_identities(raw_msg, copy=False)
344 347 msg = self.session.unpack_message(msg, copy=False, content=False)
345 348 header = msg['header']
346 349
347 350 try:
348 351 raise why()
349 352 except:
350 353 content = error.wrap_exception()
351 354
352 355 self.all_done.add(msg_id)
353 356 self.all_failed.add(msg_id)
354 357
355 358 msg = self.session.send(self.client_stream, 'apply_reply', content,
356 359 parent=header, ident=idents)
357 360 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
358 361
359 362 self.update_graph(msg_id, success=False)
360 363
361 364 @logged
362 365 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
363 366 """check location dependencies, and run if they are met."""
364 367 blacklist = self.blacklist.setdefault(msg_id, set())
365 if follow or targets or blacklist:
368 if follow or targets or blacklist or self.hwm:
366 369 # we need a can_run filter
367 370 def can_run(idx):
368 target = self.targets[idx]
369 # check targets
370 if targets and target not in targets:
371 # check hwm
372 if self.hwm and self.loads[idx] == self.hwm:
371 373 return False
374 target = self.targets[idx]
372 375 # check blacklist
373 376 if target in blacklist:
374 377 return False
378 # check targets
379 if targets and target not in targets:
380 return False
375 381 # check follow
376 382 return follow.check(self.completed[target], self.failed[target])
377 383
378 384 indices = filter(can_run, range(len(self.targets)))
379 385 if not indices:
380 386 # couldn't run
381 387 if follow.all:
382 388 # check follow for impossibility
383 389 dests = set()
384 390 relevant = set()
385 391 if follow.success:
386 392 relevant = self.all_completed
387 393 if follow.failure:
388 394 relevant = relevant.union(self.all_failed)
389 395 for m in follow.intersection(relevant):
390 396 dests.add(self.destinations[m])
391 397 if len(dests) > 1:
392 398 self.fail_unreachable(msg_id)
393 399 return False
394 400 if targets:
395 401 # check blacklist+targets for impossibility
396 402 targets.difference_update(blacklist)
397 403 if not targets or not targets.intersection(self.targets):
398 404 self.fail_unreachable(msg_id)
399 405 return False
400 406 return False
401 407 else:
402 408 indices = None
403 409
404 410 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
405 411 return True
406 412
407 413 @logged
408 414 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
409 415 """Save a message for later submission when its dependencies are met."""
410 416 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
411 417 # track the ids in follow or after, but not those already finished
412 418 for dep_id in after.union(follow).difference(self.all_done):
413 419 if dep_id not in self.graph:
414 420 self.graph[dep_id] = set()
415 421 self.graph[dep_id].add(msg_id)
416 422
417 423 @logged
418 424 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
419 425 """Submit a task to any of a subset of our targets."""
420 426 if indices:
421 427 loads = [self.loads[i] for i in indices]
422 428 else:
423 429 loads = self.loads
424 430 idx = self.scheme(loads)
425 431 if indices:
426 432 idx = indices[idx]
427 433 target = self.targets[idx]
428 434 # print (target, map(str, msg[:3]))
435 # send job to the engine
429 436 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
430 437 self.engine_stream.send_multipart(raw_msg, copy=False)
438 # update load
431 439 self.add_job(idx)
432 440 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
441 # notify Hub
433 442 content = dict(msg_id=msg_id, engine_id=target)
434 443 self.session.send(self.mon_stream, 'task_destination', content=content,
435 444 ident=['tracktask',self.session.session])
445
436 446
437 447 #-----------------------------------------------------------------------
438 448 # Result Handling
439 449 #-----------------------------------------------------------------------
440 450 @logged
441 451 def dispatch_result(self, raw_msg):
442 452 """dispatch method for result replies"""
443 453 try:
444 454 idents,msg = self.session.feed_identities(raw_msg, copy=False)
445 455 msg = self.session.unpack_message(msg, content=False, copy=False)
446 except:
456 engine = idents[0]
457 idx = self.targets.index(engine)
458 self.finish_job(idx)
459 except Exception:
447 460 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
448 461 return
449
462
450 463 header = msg['header']
451 464 if header.get('dependencies_met', True):
452 465 success = (header['status'] == 'ok')
453 466 self.handle_result(idents, msg['parent_header'], raw_msg, success)
454 467 # send to Hub monitor
455 468 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
456 469 else:
457 470 self.handle_unmet_dependency(idents, msg['parent_header'])
458 471
459 472 @logged
460 473 def handle_result(self, idents, parent, raw_msg, success=True):
461 474 """handle a real task result, either success or failure"""
462 475 # first, relay result to client
463 476 engine = idents[0]
464 477 client = idents[1]
465 478 # swap_ids for XREP-XREP mirror
466 479 raw_msg[:2] = [client,engine]
467 480 # print (map(str, raw_msg[:4]))
468 481 self.client_stream.send_multipart(raw_msg, copy=False)
469 482 # now, update our data structures
470 483 msg_id = parent['msg_id']
471 484 self.blacklist.pop(msg_id, None)
472 485 self.pending[engine].pop(msg_id)
473 486 if success:
474 487 self.completed[engine].add(msg_id)
475 488 self.all_completed.add(msg_id)
476 489 else:
477 490 self.failed[engine].add(msg_id)
478 491 self.all_failed.add(msg_id)
479 492 self.all_done.add(msg_id)
480 493 self.destinations[msg_id] = engine
481 494
482 495 self.update_graph(msg_id, success)
483 496
484 497 @logged
485 498 def handle_unmet_dependency(self, idents, parent):
486 499 """handle an unmet dependency"""
487 500 engine = idents[0]
488 501 msg_id = parent['msg_id']
489 502
490 503 if msg_id not in self.blacklist:
491 504 self.blacklist[msg_id] = set()
492 505 self.blacklist[msg_id].add(engine)
493 506
494 507 args = self.pending[engine].pop(msg_id)
495 508 raw,targets,after,follow,timeout = args
496 509
497 510 if self.blacklist[msg_id] == targets:
498 511 self.depending[msg_id] = args
499 return self.fail_unreachable(msg_id)
500
512 self.fail_unreachable(msg_id)
501 513 elif not self.maybe_run(msg_id, *args):
502 514 # resubmit failed, put it back in our dependency tree
503 515 self.save_unmet(msg_id, *args)
504 516
517 if self.hwm:
518 idx = self.targets.index(engine)
519 if self.loads[idx] == self.hwm-1:
520 self.update_graph(None)
521
522
505 523
506 524 @logged
507 def update_graph(self, dep_id, success=True):
525 def update_graph(self, dep_id=None, success=True):
508 526 """dep_id just finished. Update our dependency
509 graph and submit any jobs that just became runnable."""
527 graph and submit any jobs that just became runnable.
528
529 Called with dep_id=None to update graph for hwm, but without finishing
530 a task.
531 """
510 532 # print ("\n\n***********")
511 533 # pprint (dep_id)
512 534 # pprint (self.graph)
513 535 # pprint (self.depending)
514 536 # pprint (self.all_completed)
515 537 # pprint (self.all_failed)
516 538 # print ("\n\n***********\n\n")
517 if dep_id not in self.graph:
518 return
519 jobs = self.graph.pop(dep_id)
539 # update any jobs that depended on the dependency
540 jobs = self.graph.pop(dep_id, [])
541 # if we have HWM and an engine just became no longer full
542 # recheck *all* jobs:
543 if self.hwm and any( [ load==self.hwm-1 for load in self.loads]):
544 jobs = self.depending.keys()
520 545
521 546 for msg_id in jobs:
522 547 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
523 548
524 549 if after.unreachable(self.all_completed, self.all_failed) or follow.unreachable(self.all_completed, self.all_failed):
525 550 self.fail_unreachable(msg_id)
526 551
527 552 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
528 553 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
529 554
530 555 self.depending.pop(msg_id)
531 556 for mid in follow.union(after):
532 557 if mid in self.graph:
533 558 self.graph[mid].remove(msg_id)
534 559
535 560 #----------------------------------------------------------------------
536 561 # methods to be overridden by subclasses
537 562 #----------------------------------------------------------------------
538 563
539 564 def add_job(self, idx):
540 565 """Called after self.targets[idx] just got the job with header.
541 566 Override with subclasses. The default ordering is simple LRU.
542 567 The default loads are the number of outstanding jobs."""
543 568 self.loads[idx] += 1
544 569 for lis in (self.targets, self.loads):
545 570 lis.append(lis.pop(idx))
546 571
547 572
548 573 def finish_job(self, idx):
549 574 """Called after self.targets[idx] just finished a job.
550 575 Override with subclasses."""
551 576 self.loads[idx] -= 1
552 577
553 578
554 579
555 580 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
556 581 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
557 582 identity=b'task'):
558 583 from zmq.eventloop import ioloop
559 584 from zmq.eventloop.zmqstream import ZMQStream
560 585
561 586 if config:
562 587 # unwrap dict back into Config
563 588 config = Config(config)
564 589
565 590 ctx = zmq.Context()
566 591 loop = ioloop.IOLoop()
567 592 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
568 593 ins.setsockopt(zmq.IDENTITY, identity)
569 594 ins.bind(in_addr)
570 595
571 596 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
572 597 outs.setsockopt(zmq.IDENTITY, identity)
573 598 outs.bind(out_addr)
574 599 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
575 600 mons.connect(mon_addr)
576 601 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
577 602 nots.setsockopt(zmq.SUBSCRIBE, '')
578 603 nots.connect(not_addr)
579 604
580 605 scheme = globals().get(scheme, None)
581 606 # setup logging
582 607 if log_addr:
583 608 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
584 609 else:
585 610 local_logger(logname, loglevel)
586 611
587 612 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
588 613 mon_stream=mons, notifier_stream=nots,
589 614 scheme=scheme, loop=loop, logname=logname,
590 615 config=config)
591 616 scheduler.start()
592 617 try:
593 618 loop.start()
594 619 except KeyboardInterrupt:
595 620 print ("interrupted, exiting...", file=sys.__stderr__)
596 621