##// END OF EJS Templates
allow load-balancing across subsets of engines
MinRK -
Show More
@@ -326,7 +326,7 b' class Client(object):'
326 326 else:
327 327 self._registration_socket.connect(url)
328 328 self._engines = ReverseDict()
329 self._ids = set()
329 self._ids = []
330 330 self.outstanding=set()
331 331 self.results = {}
332 332 self.metadata = {}
@@ -370,7 +370,8 b' class Client(object):'
370 370 for k,v in engines.iteritems():
371 371 eid = int(k)
372 372 self._engines[eid] = bytes(v) # force not unicode
373 self._ids.add(eid)
373 self._ids.append(eid)
374 self._ids = sorted(self._ids)
374 375 if sorted(self._engines.keys()) != range(len(self._engines)) and \
375 376 self._task_scheme == 'pure' and self._task_socket:
376 377 self._stop_scheduling_tasks()
@@ -470,7 +471,6 b' class Client(object):'
470 471 eid = content['id']
471 472 d = {eid : content['queue']}
472 473 self._update_engines(d)
473 self._ids.add(int(eid))
474 474
475 475 def _unregister_engine(self, msg):
476 476 """Unregister an engine that has died."""
@@ -664,9 +664,9 b' class Client(object):'
664 664 """property for convenient RemoteFunction generation.
665 665
666 666 >>> @client.remote
667 ... def f():
667 ... def getpid():
668 668 import os
669 print (os.getpid())
669 return os.getpid()
670 670 """
671 671 return remote(self, block=self.block)
672 672
@@ -867,6 +867,7 b' class Client(object):'
867 867 # pass to Dependency constructor
868 868 return list(Dependency(dep))
869 869
870 @defaultblock
870 871 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
871 872 after=None, follow=None, timeout=None):
872 873 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
@@ -903,24 +904,9 b' class Client(object):'
903 904 Run on each specified engine
904 905 if int:
905 906 Run on single engine
906
907 after : Dependency or collection of msg_ids
908 Only for load-balanced execution (targets=None)
909 Specify a list of msg_ids as a time-based dependency.
910 This job will only be run *after* the dependencies
911 have been met.
912
913 follow : Dependency or collection of msg_ids
914 Only for load-balanced execution (targets=None)
915 Specify a list of msg_ids as a location-based dependency.
916 This job will only be run on an engine where this dependency
917 is met.
918 907
919 timeout : float/int or None
920 Only for load-balanced execution (targets=None)
921 Specify an amount of time (in seconds) for the scheduler to
922 wait for dependencies to be met before failing with a
923 DependencyTimeout.
908 after,follow,timeout only used in `apply_balanced`. See that docstring
909 for details.
924 910
925 911 Returns
926 912 -------
@@ -947,25 +933,88 b' class Client(object):'
947 933 if not isinstance(kwargs, dict):
948 934 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
949 935
950 options = dict(bound=bound, block=block)
936 options = dict(bound=bound, block=block, targets=targets)
951 937
952 938 if targets is None:
953 if self._task_socket:
954 return self._apply_balanced(f, args, kwargs, timeout=timeout,
955 after=after, follow=follow, **options)
956 else:
957 msg = "Task farming is disabled"
958 if self._task_scheme == 'pure':
959 msg += " because the pure ZMQ scheduler cannot handle"
960 msg += " disappearing engines."
961 raise RuntimeError(msg)
939 return self.apply_balanced(f, args, kwargs, timeout=timeout,
940 after=after, follow=follow, **options)
962 941 else:
963 return self._apply_direct(f, args, kwargs, targets=targets, **options)
942 if follow or after or timeout:
943 msg = "follow, after, and timeout args are only used for load-balanced"
944 msg += "execution."
945 raise ValueError(msg)
946 return self._apply_direct(f, args, kwargs, **options)
964 947
965 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
948 @defaultblock
949 def apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
966 950 after=None, follow=None, timeout=None):
967 """The underlying method for applying functions in a load balanced
968 manner, via the task queue."""
951 """call f(*args, **kwargs) remotely in a load-balanced manner.
952
953 Parameters
954 ----------
955
956 f : function
957 The function to be called remotely
958 args : tuple/list
959 The positional arguments passed to `f`
960 kwargs : dict
961 The keyword arguments passed to `f`
962 bound : bool (default: True)
963 Whether to execute in the Engine(s) namespace, or in a clean
964 namespace not affecting the engine.
965 block : bool (default: self.block)
966 Whether to wait for the result, or return immediately.
967 False:
968 returns AsyncResult
969 True:
970 returns actual result(s) of f(*args, **kwargs)
971 if multiple targets:
972 list of results, matching `targets`
973 targets : int,list of ints, 'all', None
974 Specify the destination of the job.
975 if None:
976 Submit via Task queue for load-balancing.
977 if 'all':
978 Run on all active engines
979 if list:
980 Run on each specified engine
981 if int:
982 Run on single engine
983
984 after : Dependency or collection of msg_ids
985 Only for load-balanced execution (targets=None)
986 Specify a list of msg_ids as a time-based dependency.
987 This job will only be run *after* the dependencies
988 have been met.
989
990 follow : Dependency or collection of msg_ids
991 Only for load-balanced execution (targets=None)
992 Specify a list of msg_ids as a location-based dependency.
993 This job will only be run on an engine where this dependency
994 is met.
995
996 timeout : float/int or None
997 Only for load-balanced execution (targets=None)
998 Specify an amount of time (in seconds) for the scheduler to
999 wait for dependencies to be met before failing with a
1000 DependencyTimeout.
1001
1002 Returns
1003 -------
1004 if block is False:
1005 return AsyncResult wrapping msg_id
1006 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1007 else:
1008 wait for, and return actual result of `f(*args, **kwargs)`
1009
1010 """
1011
1012 if self._task_socket is None:
1013 msg = "Task farming is disabled"
1014 if self._task_scheme == 'pure':
1015 msg += " because the pure ZMQ scheduler cannot handle"
1016 msg += " disappearing engines."
1017 raise RuntimeError(msg)
969 1018
970 1019 if self._task_scheme == 'pure':
971 1020 # pure zmq scheme doesn't support dependencies
@@ -978,9 +1027,26 b' class Client(object):'
978 1027 warnings.warn(msg, RuntimeWarning)
979 1028
980 1029
1030 # defaults:
1031 args = args if args is not None else []
1032 kwargs = kwargs if kwargs is not None else {}
1033
1034 if targets:
1035 idents,_ = self._build_targets(targets)
1036 else:
1037 idents = []
1038
1039 # enforce types of f,args,kwargs
1040 if not callable(f):
1041 raise TypeError("f must be callable, not %s"%type(f))
1042 if not isinstance(args, (tuple, list)):
1043 raise TypeError("args must be tuple or list, not %s"%type(args))
1044 if not isinstance(kwargs, dict):
1045 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1046
981 1047 after = self._build_dependency(after)
982 1048 follow = self._build_dependency(follow)
983 subheader = dict(after=after, follow=follow, timeout=timeout)
1049 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
984 1050 bufs = ss.pack_apply_message(f,args,kwargs)
985 1051 content = dict(bound=bound)
986 1052
@@ -991,31 +1057,40 b' class Client(object):'
991 1057 self.history.append(msg_id)
992 1058 ar = AsyncResult(self, [msg_id], fname=f.__name__)
993 1059 if block:
994 return ar.get()
1060 try:
1061 return ar.get()
1062 except KeyboardInterrupt:
1063 return ar
995 1064 else:
996 1065 return ar
997 1066
998 1067 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
999 1068 """The underlying method for applying functions to specific engines
1000 via the MUX queue."""
1069 via the MUX queue.
1070
1071 Not to be called directly!
1072 """
1001 1073
1002 queues,targets = self._build_targets(targets)
1074 idents,targets = self._build_targets(targets)
1003 1075
1004 1076 subheader = {}
1005 1077 content = dict(bound=bound)
1006 1078 bufs = ss.pack_apply_message(f,args,kwargs)
1007 1079
1008 1080 msg_ids = []
1009 for queue in queues:
1081 for ident in idents:
1010 1082 msg = self.session.send(self._mux_socket, "apply_request",
1011 content=content, buffers=bufs,ident=queue, subheader=subheader)
1083 content=content, buffers=bufs, ident=ident, subheader=subheader)
1012 1084 msg_id = msg['msg_id']
1013 1085 self.outstanding.add(msg_id)
1014 1086 self.history.append(msg_id)
1015 1087 msg_ids.append(msg_id)
1016 1088 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1017 1089 if block:
1018 return ar.get()
1090 try:
1091 return ar.get()
1092 except KeyboardInterrupt:
1093 return ar
1019 1094 else:
1020 1095 return ar
1021 1096
@@ -1037,6 +1112,15 b' class Client(object):'
1037 1112 """Decorator for making a RemoteFunction."""
1038 1113 return remote(self, bound=bound, targets=targets, block=block)
1039 1114
1115 def view(self, targets=None, balanced=False):
1116 """Method for constructing View objects"""
1117 if not balanced:
1118 if not targets:
1119 targets = slice(None)
1120 return self[targets]
1121 else:
1122 return LoadBalancedView(self, targets)
1123
1040 1124 #--------------------------------------------------------------------------
1041 1125 # Data movement
1042 1126 #--------------------------------------------------------------------------
@@ -265,6 +265,9 b' class TaskScheduler(SessionFactory):'
265 265 msg_id = header['msg_id']
266 266 self.all_ids.add(msg_id)
267 267
268 # targets
269 targets = set(header.get('targets', []))
270
268 271 # time dependencies
269 272 after = Dependency(header.get('after', []))
270 273 if after.all:
@@ -279,28 +282,31 b' class TaskScheduler(SessionFactory):'
279 282 # location dependencies
280 283 follow = Dependency(header.get('follow', []))
281 284
285 # turn timeouts into datetime objects:
286 timeout = header.get('timeout', None)
287 if timeout:
288 timeout = datetime.now() + timedelta(0,timeout,0)
289
290 args = [raw_msg, targets, after, follow, timeout]
291
292 # validate and reduce dependencies:
282 293 for dep in after,follow:
283 294 # check valid:
284 295 if msg_id in dep or dep.difference(self.all_ids):
285 self.depending[msg_id] = [raw_msg,MET,MET,None]
296 self.depending[msg_id] = args
286 297 return self.fail_unreachable(msg_id, error.InvalidDependency)
287 298 # check if unreachable:
288 299 if dep.unreachable(self.all_failed):
289 self.depending[msg_id] = [raw_msg,MET,MET,None]
300 self.depending[msg_id] = args
290 301 return self.fail_unreachable(msg_id)
291 302
292 # turn timeouts into datetime objects:
293 timeout = header.get('timeout', None)
294 if timeout:
295 timeout = datetime.now() + timedelta(0,timeout,0)
296
297 303 if after.check(self.all_completed, self.all_failed):
298 304 # time deps already met, try to run
299 if not self.maybe_run(msg_id, raw_msg, follow, timeout):
305 if not self.maybe_run(msg_id, *args):
300 306 # can't run yet
301 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
307 self.save_unmet(msg_id, *args)
302 308 else:
303 self.save_unmet(msg_id, raw_msg, after, follow, timeout)
309 self.save_unmet(msg_id, *args)
304 310
305 311 # @logged
306 312 def audit_timeouts(self):
@@ -309,17 +315,18 b' class TaskScheduler(SessionFactory):'
309 315 for msg_id in self.depending.keys():
310 316 # must recheck, in case one failure cascaded to another:
311 317 if msg_id in self.depending:
312 raw,after,follow,timeout = self.depending[msg_id]
318 raw,after,targets,follow,timeout = self.depending[msg_id]
313 319 if timeout and timeout < now:
314 320 self.fail_unreachable(msg_id, timeout=True)
315 321
316 322 @logged
317 323 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
318 """a message has become unreachable"""
324 """a task has become unreachable, send a reply with an ImpossibleDependency
325 error."""
319 326 if msg_id not in self.depending:
320 327 self.log.error("msg %r already failed!"%msg_id)
321 328 return
322 raw_msg, after, follow, timeout = self.depending.pop(msg_id)
329 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
323 330 for mid in follow.union(after):
324 331 if mid in self.graph:
325 332 self.graph[mid].remove(msg_id)
@@ -344,37 +351,51 b' class TaskScheduler(SessionFactory):'
344 351 self.update_graph(msg_id, success=False)
345 352
346 353 @logged
347 def maybe_run(self, msg_id, raw_msg, follow=None, timeout=None):
354 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
348 355 """check location dependencies, and run if they are met."""
349
350 if follow:
356 blacklist = self.blacklist.setdefault(msg_id, set())
357 if follow or targets or blacklist:
358 # we need a can_run filter
351 359 def can_run(idx):
352 360 target = self.targets[idx]
353 return target not in self.blacklist.get(msg_id, []) and\
354 follow.check(self.completed[target], self.failed[target])
361 # check targets
362 if targets and target not in targets:
363 return False
364 # check blacklist
365 if target in blacklist:
366 return False
367 # check follow
368 return follow.check(self.completed[target], self.failed[target])
355 369
356 370 indices = filter(can_run, range(len(self.targets)))
357 371 if not indices:
372 # couldn't run
358 373 if follow.all:
374 # check follow for impossibility
359 375 dests = set()
360 376 relevant = self.all_completed if follow.success_only else self.all_done
361 377 for m in follow.intersection(relevant):
362 378 dests.add(self.destinations[m])
363 379 if len(dests) > 1:
364 380 self.fail_unreachable(msg_id)
365
366
381 return False
382 if targets:
383 # check blacklist+targets for impossibility
384 targets.difference_update(blacklist)
385 if not targets or not targets.intersection(self.targets):
386 self.fail_unreachable(msg_id)
387 return False
367 388 return False
368 389 else:
369 390 indices = None
370 391
371 self.submit_task(msg_id, raw_msg, follow, timeout, indices)
392 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
372 393 return True
373 394
374 395 @logged
375 def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
396 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
376 397 """Save a message for later submission when its dependencies are met."""
377 self.depending[msg_id] = [raw_msg,after,follow,timeout]
398 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
378 399 # track the ids in follow or after, but not those already finished
379 400 for dep_id in after.union(follow).difference(self.all_done):
380 401 if dep_id not in self.graph:
@@ -382,7 +403,7 b' class TaskScheduler(SessionFactory):'
382 403 self.graph[dep_id].add(msg_id)
383 404
384 405 @logged
385 def submit_task(self, msg_id, raw_msg, follow, timeout, indices=None):
406 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
386 407 """Submit a task to any of a subset of our targets."""
387 408 if indices:
388 409 loads = [self.loads[i] for i in indices]
@@ -396,7 +417,7 b' class TaskScheduler(SessionFactory):'
396 417 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
397 418 self.engine_stream.send_multipart(raw_msg, copy=False)
398 419 self.add_job(idx)
399 self.pending[target][msg_id] = (raw_msg, follow, timeout)
420 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
400 421 content = dict(msg_id=msg_id, engine_id=target)
401 422 self.session.send(self.mon_stream, 'task_destination', content=content,
402 423 ident=['tracktask',self.session.session])
@@ -406,6 +427,7 b' class TaskScheduler(SessionFactory):'
406 427 #-----------------------------------------------------------------------
407 428 @logged
408 429 def dispatch_result(self, raw_msg):
430 """dispatch method for result replies"""
409 431 try:
410 432 idents,msg = self.session.feed_identities(raw_msg, copy=False)
411 433 msg = self.session.unpack_message(msg, content=False, copy=False)
@@ -424,6 +446,7 b' class TaskScheduler(SessionFactory):'
424 446
425 447 @logged
426 448 def handle_result(self, idents, parent, raw_msg, success=True):
449 """handle a real task result, either success or failure"""
427 450 # first, relay result to client
428 451 engine = idents[0]
429 452 client = idents[1]
@@ -448,21 +471,30 b' class TaskScheduler(SessionFactory):'
448 471
449 472 @logged
450 473 def handle_unmet_dependency(self, idents, parent):
474 """handle an unmet dependency"""
451 475 engine = idents[0]
452 476 msg_id = parent['msg_id']
477
453 478 if msg_id not in self.blacklist:
454 479 self.blacklist[msg_id] = set()
455 480 self.blacklist[msg_id].add(engine)
456 raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
457 if not self.maybe_run(msg_id, raw_msg, follow, timeout):
481
482 args = self.pending[engine].pop(msg_id)
483 raw,targets,after,follow,timeout = args
484
485 if self.blacklist[msg_id] == targets:
486 self.depending[msg_id] = args
487 return self.fail_unreachable(msg_id)
488
489 elif not self.maybe_run(msg_id, *args):
458 490 # resubmit failed, put it back in our dependency tree
459 self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
460 pass
491 self.save_unmet(msg_id, *args)
492
461 493
462 494 @logged
463 495 def update_graph(self, dep_id, success=True):
464 496 """dep_id just finished. Update our dependency
465 table and submit any jobs that just became runable."""
497 graph and submit any jobs that just became runnable."""
466 498 # print ("\n\n***********")
467 499 # pprint (dep_id)
468 500 # pprint (self.graph)
@@ -475,7 +507,7 b' class TaskScheduler(SessionFactory):'
475 507 jobs = self.graph.pop(dep_id)
476 508
477 509 for msg_id in jobs:
478 raw_msg, after, follow, timeout = self.depending[msg_id]
510 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
479 511 # if dep_id in after:
480 512 # if after.all and (success or not after.success_only):
481 513 # after.remove(dep_id)
@@ -484,8 +516,7 b' class TaskScheduler(SessionFactory):'
484 516 self.fail_unreachable(msg_id)
485 517
486 518 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
487 self.depending[msg_id][1] = MET
488 if self.maybe_run(msg_id, raw_msg, follow, timeout):
519 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
489 520
490 521 self.depending.pop(msg_id)
491 522 for mid in follow.union(after):
@@ -66,10 +66,15 b' class View(object):'
66 66
67 67 Don't use this class, use subclasses.
68 68 """
69 _targets = None
70 69 block=None
71 70 bound=None
72 71 history=None
72 outstanding = set()
73 results = {}
74
75 _targets = None
76 _apply_name = 'apply'
77 _default_names = ['targets', 'block']
73 78
74 79 def __init__(self, client, targets=None):
75 80 self.client = client
@@ -80,6 +85,9 b' class View(object):'
80 85 self.history = []
81 86 self.outstanding = set()
82 87 self.results = {}
88 for name in self._default_names:
89 setattr(self, name, getattr(self, name, None))
90
83 91
84 92 def __repr__(self):
85 93 strtargets = str(self._targets)
@@ -95,11 +103,23 b' class View(object):'
95 103 def targets(self, value):
96 104 self._targets = value
97 105 # raise AttributeError("Cannot set my targets argument after construction!")
98
106
107 def _defaults(self, *excludes):
108 """return dict of our default attributes, excluding names given."""
109 d = {}
110 for name in self._default_names:
111 if name not in excludes:
112 d[name] = getattr(self, name)
113 return d
114
99 115 @sync_results
100 116 def spin(self):
101 117 """spin the client, and sync"""
102 118 self.client.spin()
119
120 @property
121 def _apply(self):
122 return getattr(self.client, self._apply_name)
103 123
104 124 @sync_results
105 125 @save_ids
@@ -113,7 +133,7 b' class View(object):'
113 133 else:
114 134 returns actual result of f(*args, **kwargs)
115 135 """
116 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound)
136 return self._apply(f, args, kwargs, **self._defaults())
117 137
118 138 @save_ids
119 139 def apply_async(self, f, *args, **kwargs):
@@ -123,7 +143,8 b' class View(object):'
123 143
124 144 returns msg_id
125 145 """
126 return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False)
146 d = self._defaults('block', 'bound')
147 return self._apply(f,args,kwargs, block=False, bound=False, **d)
127 148
128 149 @spin_after
129 150 @save_ids
@@ -135,7 +156,8 b' class View(object):'
135 156
136 157 returns: actual result of f(*args, **kwargs)
137 158 """
138 return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False)
159 d = self._defaults('block', 'bound')
160 return self._apply(f,args,kwargs, block=True, bound=False, **d)
139 161
140 162 @sync_results
141 163 @save_ids
@@ -150,7 +172,8 b' class View(object):'
150 172 This method has access to the targets' globals
151 173
152 174 """
153 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True)
175 d = self._defaults('bound')
176 return self._apply(f, args, kwargs, bound=True, **d)
154 177
155 178 @sync_results
156 179 @save_ids
@@ -163,7 +186,8 b' class View(object):'
163 186 This method has access to the targets' globals
164 187
165 188 """
166 return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True)
189 d = self._defaults('block', 'bound')
190 return self._apply(f, args, kwargs, block=False, bound=True, **d)
167 191
168 192 @spin_after
169 193 @save_ids
@@ -175,7 +199,8 b' class View(object):'
175 199 This method has access to the targets' globals
176 200
177 201 """
178 return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True)
202 d = self._defaults('block', 'bound')
203 return self._apply(f, args, kwargs, block=True, bound=True, **d)
179 204
180 205 @spin_after
181 206 @save_ids
@@ -337,24 +362,22 b' class LoadBalancedView(View):'
337 362
338 363 Typically created via:
339 364
340 >>> lbv = client[None]
341 <LoadBalancedView tcp://127.0.0.1:12345>
365 >>> v = client[None]
366 <LoadBalancedView None>
342 367
343 368 but can also be created with:
344 369
345 >>> lbc = LoadBalancedView(client)
370 >>> v = client.view([1,3],balanced=True)
371
372 which would restrict loadbalancing to between engines 1 and 3.
346 373
347 TODO: allow subset of engines across which to balance.
348 374 """
349 def __repr__(self):
350 return "<%s %s>"%(self.__class__.__name__, self.client._config['url'])
351 375
352 @property
353 def targets(self):
354 return None
355
356 @targets.setter
357 def targets(self, value):
358 raise AttributeError("Cannot set targets for LoadbalancedView!")
359
360 No newline at end of file
376 _apply_name = 'apply_balanced'
377 _default_names = ['targets', 'block', 'bound', 'follow', 'after', 'timeout']
378
379 def __init__(self, client, targets=None):
380 super(LoadBalancedView, self).__init__(client, targets)
381 self._ntargets = 1
382 self._apply_name = 'apply_balanced'
383
General Comments 0
You need to be logged in to leave comments. Login now