##// END OF EJS Templates
basic LoadBalancedView, RemoteFunction
MinRK -
Show More
@@ -1,818 +1,854
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import time
16 16 from pprint import pprint
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 21 from IPython.external.decorator import decorator
22 22
23 23 import streamsession as ss
24 from remotenamespace import RemoteNamespace
25 from view import DirectView
24 # from remotenamespace import RemoteNamespace
25 from view import DirectView, LoadBalancedView
26 26 from dependency import Dependency, depend, require
27 27
28 28 def _push(ns):
29 29 globals().update(ns)
30 30
31 31 def _pull(keys):
32 32 g = globals()
33 33 if isinstance(keys, (list,tuple, set)):
34 for key in keys:
35 if not g.has_key(key):
36 raise NameError("name '%s' is not defined"%key)
34 37 return map(g.get, keys)
35 38 else:
39 if not g.has_key(keys):
40 raise NameError("name '%s' is not defined"%keys)
36 41 return g.get(keys)
37 42
38 43 def _clear():
39 44 globals().clear()
40 45
41 46 def execute(code):
42 47 exec code in globals()
43 48
44 49 #--------------------------------------------------------------------------
45 50 # Decorators for Client methods
46 51 #--------------------------------------------------------------------------
47 52
48 53 @decorator
49 54 def spinfirst(f, self, *args, **kwargs):
50 55 """Call spin() to sync state prior to calling the method."""
51 56 self.spin()
52 57 return f(self, *args, **kwargs)
53 58
54 59 @decorator
55 60 def defaultblock(f, self, *args, **kwargs):
56 61 """Default to self.block; preserve self.block."""
57 62 block = kwargs.get('block',None)
58 63 block = self.block if block is None else block
59 64 saveblock = self.block
60 65 self.block = block
61 66 ret = f(self, *args, **kwargs)
62 67 self.block = saveblock
63 68 return ret
64 69
70 def remote(client, block=None, targets=None):
71 """Turn a function into a remote function.
72
73 This method can be used for map:
74
75 >>> @remote(client,block=True)
76 def func(a)
77 """
78 def remote_function(f):
79 return RemoteFunction(client, f, block, targets)
80 return remote_function
81
65 82 #--------------------------------------------------------------------------
66 83 # Classes
67 84 #--------------------------------------------------------------------------
68 85
86 class RemoteFunction(object):
87 """Turn an existing function into a remote function"""
88
89 def __init__(self, client, f, block=None, targets=None):
90 self.client = client
91 self.func = f
92 self.block=block
93 self.targets=targets
94
95 def __call__(self, *args, **kwargs):
96 return self.client.apply(self.func, args=args, kwargs=kwargs,
97 block=self.block, targets=self.targets)
98
69 99
70 100 class AbortedTask(object):
71 101 """A basic wrapper object describing an aborted task."""
72 102 def __init__(self, msg_id):
73 103 self.msg_id = msg_id
74 104
75 105 class ControllerError(Exception):
76 106 def __init__(self, etype, evalue, tb):
77 107 self.etype = etype
78 108 self.evalue = evalue
79 109 self.traceback=tb
80 110
81 111 class Client(object):
82 112 """A semi-synchronous client to the IPython ZMQ controller
83 113
84 114 Parameters
85 115 ----------
86 116
87 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101
117 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
88 118 The address of the controller's registration socket.
89 119
90 120
91 121 Attributes
92 122 ----------
93 123 ids : set of int engine IDs
94 124 requesting the ids attribute always synchronizes
95 125 the registration state. To request ids without synchronization,
96 126 use semi-private _ids.
97 127
98 128 history : list of msg_ids
99 129 a list of msg_ids, keeping track of all the execution
100 130 messages you have submitted in order.
101 131
102 132 outstanding : set of msg_ids
103 133 a set of msg_ids that have been submitted, but whose
104 134 results have not yet been received.
105 135
106 136 results : dict
107 137 a dict of all our results, keyed by msg_id
108 138
109 139 block : bool
110 140 determines default behavior when block not specified
111 141 in execution methods
112 142
113 143 Methods
114 144 -------
115 145 spin : flushes incoming results and registration state changes
116 146 control methods spin, and requesting `ids` also ensures up to date
117 147
118 148 barrier : wait on one or more msg_ids
119 149
120 150 execution methods: apply/apply_bound/apply_to/applu_bount
121 151 legacy: execute, run
122 152
123 153 query methods: queue_status, get_result, purge
124 154
125 155 control methods: abort, kill
126 156
127 157 """
128 158
129 159
130 160 _connected=False
131 161 _engines=None
132 162 _addr='tcp://127.0.0.1:10101'
133 163 _registration_socket=None
134 164 _query_socket=None
135 165 _control_socket=None
136 166 _notification_socket=None
137 167 _mux_socket=None
138 168 _task_socket=None
139 169 block = False
140 170 outstanding=None
141 171 results = None
142 172 history = None
143 173 debug = False
144 174
145 175 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False):
146 176 if context is None:
147 177 context = zmq.Context()
148 178 self.context = context
149 179 self._addr = addr
150 180 if username is None:
151 181 self.session = ss.StreamSession()
152 182 else:
153 183 self.session = ss.StreamSession(username)
154 184 self._registration_socket = self.context.socket(zmq.PAIR)
155 185 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
156 186 self._registration_socket.connect(addr)
157 187 self._engines = {}
158 188 self._ids = set()
159 189 self.outstanding=set()
160 190 self.results = {}
161 191 self.history = []
162 192 self.debug = debug
163 193 self.session.debug = debug
164 194
165 195 self._notification_handlers = {'registration_notification' : self._register_engine,
166 196 'unregistration_notification' : self._unregister_engine,
167 197 }
168 198 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
169 199 'apply_reply' : self._handle_apply_reply}
170 200 self._connect()
171 201
172 202
173 203 @property
174 204 def ids(self):
175 205 """Always up to date ids property."""
176 206 self._flush_notifications()
177 207 return self._ids
178 208
179 209 def _update_engines(self, engines):
180 210 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
181 211 for k,v in engines.iteritems():
182 212 eid = int(k)
183 213 self._engines[eid] = bytes(v) # force not unicode
184 214 self._ids.add(eid)
185 215
186 216 def _build_targets(self, targets):
187 217 """Turn valid target IDs or 'all' into two lists:
188 218 (int_ids, uuids).
189 219 """
190 220 if targets is None:
191 221 targets = self._ids
192 222 elif isinstance(targets, str):
193 223 if targets.lower() == 'all':
194 224 targets = self._ids
195 225 else:
196 226 raise TypeError("%r not valid str target, must be 'all'"%(targets))
197 227 elif isinstance(targets, int):
198 228 targets = [targets]
199 229 return [self._engines[t] for t in targets], list(targets)
200 230
201 231 def _connect(self):
202 232 """setup all our socket connections to the controller. This is called from
203 233 __init__."""
204 234 if self._connected:
205 235 return
206 236 self._connected=True
207 237 self.session.send(self._registration_socket, 'connection_request')
208 238 idents,msg = self.session.recv(self._registration_socket,mode=0)
209 239 if self.debug:
210 240 pprint(msg)
211 241 msg = ss.Message(msg)
212 242 content = msg.content
213 243 if content.status == 'ok':
214 244 if content.queue:
215 245 self._mux_socket = self.context.socket(zmq.PAIR)
216 246 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
217 247 self._mux_socket.connect(content.queue)
218 248 if content.task:
219 249 self._task_socket = self.context.socket(zmq.PAIR)
220 250 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
221 251 self._task_socket.connect(content.task)
222 252 if content.notification:
223 253 self._notification_socket = self.context.socket(zmq.SUB)
224 254 self._notification_socket.connect(content.notification)
225 255 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
226 256 if content.query:
227 257 self._query_socket = self.context.socket(zmq.PAIR)
228 258 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
229 259 self._query_socket.connect(content.query)
230 260 if content.control:
231 261 self._control_socket = self.context.socket(zmq.PAIR)
232 262 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
233 263 self._control_socket.connect(content.control)
234 264 self._update_engines(dict(content.engines))
235 265
236 266 else:
237 267 self._connected = False
238 268 raise Exception("Failed to connect!")
239 269
240 270 #--------------------------------------------------------------------------
241 271 # handlers and callbacks for incoming messages
242 272 #--------------------------------------------------------------------------
243 273
244 274 def _register_engine(self, msg):
245 275 """Register a new engine, and update our connection info."""
246 276 content = msg['content']
247 277 eid = content['id']
248 278 d = {eid : content['queue']}
249 279 self._update_engines(d)
250 280 self._ids.add(int(eid))
251 281
252 282 def _unregister_engine(self, msg):
253 283 """Unregister an engine that has died."""
254 284 content = msg['content']
255 285 eid = int(content['id'])
256 286 if eid in self._ids:
257 287 self._ids.remove(eid)
258 288 self._engines.pop(eid)
259 289
260 290 def _handle_execute_reply(self, msg):
261 291 """Save the reply to an execute_request into our results."""
262 292 parent = msg['parent_header']
263 293 msg_id = parent['msg_id']
264 294 if msg_id not in self.outstanding:
265 295 print("got unknown result: %s"%msg_id)
266 296 else:
267 297 self.outstanding.remove(msg_id)
268 298 self.results[msg_id] = ss.unwrap_exception(msg['content'])
269 299
270 300 def _handle_apply_reply(self, msg):
271 301 """Save the reply to an apply_request into our results."""
272 302 parent = msg['parent_header']
273 303 msg_id = parent['msg_id']
274 304 if msg_id not in self.outstanding:
275 305 print ("got unknown result: %s"%msg_id)
276 306 else:
277 307 self.outstanding.remove(msg_id)
278 308 content = msg['content']
279 309 if content['status'] == 'ok':
280 310 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
281 311 elif content['status'] == 'aborted':
282 312 self.results[msg_id] = AbortedTask(msg_id)
283 313 elif content['status'] == 'resubmitted':
284 pass # handle resubmission
314 # TODO: handle resubmission
315 pass
285 316 else:
286 317 self.results[msg_id] = ss.unwrap_exception(content)
287 318
288 319 def _flush_notifications(self):
289 320 """Flush notifications of engine registrations waiting
290 321 in ZMQ queue."""
291 322 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
292 323 while msg is not None:
293 324 if self.debug:
294 325 pprint(msg)
295 326 msg = msg[-1]
296 327 msg_type = msg['msg_type']
297 328 handler = self._notification_handlers.get(msg_type, None)
298 329 if handler is None:
299 330 raise Exception("Unhandled message type: %s"%msg.msg_type)
300 331 else:
301 332 handler(msg)
302 333 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
303 334
304 335 def _flush_results(self, sock):
305 336 """Flush task or queue results waiting in ZMQ queue."""
306 337 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
307 338 while msg is not None:
308 339 if self.debug:
309 340 pprint(msg)
310 341 msg = msg[-1]
311 342 msg_type = msg['msg_type']
312 343 handler = self._queue_handlers.get(msg_type, None)
313 344 if handler is None:
314 345 raise Exception("Unhandled message type: %s"%msg.msg_type)
315 346 else:
316 347 handler(msg)
317 348 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
318 349
319 350 def _flush_control(self, sock):
320 351 """Flush replies from the control channel waiting
321 in the ZMQ queue."""
352 in the ZMQ queue.
353
354 Currently: ignore them."""
322 355 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
323 356 while msg is not None:
324 357 if self.debug:
325 358 pprint(msg)
326 359 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
327 360
328 361 #--------------------------------------------------------------------------
329 362 # getitem
330 363 #--------------------------------------------------------------------------
331 364
332 365 def __getitem__(self, key):
333 """Dict access returns DirectView multiplexer objects."""
366 """Dict access returns DirectView multiplexer objects or,
367 if key is None, a LoadBalancedView."""
368 if key is None:
369 return LoadBalancedView(self)
334 370 if isinstance(key, int):
335 371 if key not in self.ids:
336 372 raise IndexError("No such engine: %i"%key)
337 373 return DirectView(self, key)
338 374
339 375 if isinstance(key, slice):
340 376 indices = range(len(self.ids))[key]
341 377 ids = sorted(self._ids)
342 378 key = [ ids[i] for i in indices ]
343 379 # newkeys = sorted(self._ids)[thekeys[k]]
344 380
345 381 if isinstance(key, (tuple, list, xrange)):
346 382 _,targets = self._build_targets(list(key))
347 383 return DirectView(self, targets)
348 384 else:
349 385 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
350 386
351 387 #--------------------------------------------------------------------------
352 388 # Begin public methods
353 389 #--------------------------------------------------------------------------
354 390
355 391 def spin(self):
356 392 """Flush any registration notifications and execution results
357 393 waiting in the ZMQ queue.
358 394 """
359 395 if self._notification_socket:
360 396 self._flush_notifications()
361 397 if self._mux_socket:
362 398 self._flush_results(self._mux_socket)
363 399 if self._task_socket:
364 400 self._flush_results(self._task_socket)
365 401 if self._control_socket:
366 402 self._flush_control(self._control_socket)
367 403
368 404 def barrier(self, msg_ids=None, timeout=-1):
369 405 """waits on one or more `msg_ids`, for up to `timeout` seconds.
370 406
371 407 Parameters
372 408 ----------
373 409 msg_ids : int, str, or list of ints and/or strs
374 410 ints are indices to self.history
375 411 strs are msg_ids
376 412 default: wait on all outstanding messages
377 413 timeout : float
378 414 a time in seconds, after which to give up.
379 415 default is -1, which means no timeout
380 416
381 417 Returns
382 418 -------
383 419 True : when all msg_ids are done
384 420 False : timeout reached, some msg_ids still outstanding
385 421 """
386 422 tic = time.time()
387 423 if msg_ids is None:
388 424 theids = self.outstanding
389 425 else:
390 426 if isinstance(msg_ids, (int, str)):
391 427 msg_ids = [msg_ids]
392 428 theids = set()
393 429 for msg_id in msg_ids:
394 430 if isinstance(msg_id, int):
395 431 msg_id = self.history[msg_id]
396 432 theids.add(msg_id)
397 433 self.spin()
398 434 while theids.intersection(self.outstanding):
399 435 if timeout >= 0 and ( time.time()-tic ) > timeout:
400 436 break
401 437 time.sleep(1e-3)
402 438 self.spin()
403 439 return len(theids.intersection(self.outstanding)) == 0
404 440
405 441 #--------------------------------------------------------------------------
406 442 # Control methods
407 443 #--------------------------------------------------------------------------
408 444
409 445 @spinfirst
410 446 @defaultblock
411 447 def clear(self, targets=None, block=None):
412 448 """Clear the namespace in target(s)."""
413 449 targets = self._build_targets(targets)[0]
414 450 for t in targets:
415 self.session.send(self._control_socket, 'clear_request', content={},ident=t)
451 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
416 452 error = False
417 453 if self.block:
418 454 for i in range(len(targets)):
419 455 idents,msg = self.session.recv(self._control_socket,0)
420 456 if self.debug:
421 457 pprint(msg)
422 458 if msg['content']['status'] != 'ok':
423 error = msg['content']
459 error = ss.unwrap_exception(msg['content'])
424 460 if error:
425 461 return error
426 462
427 463
428 464 @spinfirst
429 465 @defaultblock
430 466 def abort(self, msg_ids = None, targets=None, block=None):
431 467 """Abort the execution queues of target(s)."""
432 468 targets = self._build_targets(targets)[0]
433 469 if isinstance(msg_ids, basestring):
434 470 msg_ids = [msg_ids]
435 471 content = dict(msg_ids=msg_ids)
436 472 for t in targets:
437 473 self.session.send(self._control_socket, 'abort_request',
438 474 content=content, ident=t)
439 475 error = False
440 476 if self.block:
441 477 for i in range(len(targets)):
442 478 idents,msg = self.session.recv(self._control_socket,0)
443 479 if self.debug:
444 480 pprint(msg)
445 481 if msg['content']['status'] != 'ok':
446 error = msg['content']
482 error = ss.unwrap_exception(msg['content'])
447 483 if error:
448 484 return error
449 485
450 486 @spinfirst
451 487 @defaultblock
452 488 def kill(self, targets=None, block=None):
453 489 """Terminates one or more engine processes."""
454 490 targets = self._build_targets(targets)[0]
455 491 for t in targets:
456 492 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
457 493 error = False
458 494 if self.block:
459 495 for i in range(len(targets)):
460 496 idents,msg = self.session.recv(self._control_socket,0)
461 497 if self.debug:
462 498 pprint(msg)
463 499 if msg['content']['status'] != 'ok':
464 error = msg['content']
500 error = ss.unwrap_exception(msg['content'])
465 501 if error:
466 502 return error
467 503
468 504 #--------------------------------------------------------------------------
469 505 # Execution methods
470 506 #--------------------------------------------------------------------------
471 507
472 508 @defaultblock
473 509 def execute(self, code, targets='all', block=None):
474 510 """Executes `code` on `targets` in blocking or nonblocking manner.
475 511
476 512 Parameters
477 513 ----------
478 514 code : str
479 515 the code string to be executed
480 516 targets : int/str/list of ints/strs
481 517 the engines on which to execute
482 518 default : all
483 519 block : bool
484 520 whether or not to wait until done to return
485 521 default: self.block
486 522 """
487 523 # block = self.block if block is None else block
488 524 # saveblock = self.block
489 525 # self.block = block
490 526 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
491 527 # self.block = saveblock
492 528 return result
493 529
494 530 def run(self, code, block=None):
495 531 """Runs `code` on an engine.
496 532
497 533 Calls to this are load-balanced.
498 534
499 535 Parameters
500 536 ----------
501 537 code : str
502 538 the code string to be executed
503 539 block : bool
504 540 whether or not to wait until done
505 541
506 542 """
507 543 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
508 544 return result
509 545
510 546 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
511 547 after=None, follow=None):
512 548 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
513 549
514 550 This is the central execution command for the client.
515 551
516 552 Parameters
517 553 ----------
518 554
519 555 f : function
520 556 The fuction to be called remotely
521 557 args : tuple/list
522 558 The positional arguments passed to `f`
523 559 kwargs : dict
524 560 The keyword arguments passed to `f`
525 561 bound : bool (default: True)
526 562 Whether to execute in the Engine(s) namespace, or in a clean
527 563 namespace not affecting the engine.
528 564 block : bool (default: self.block)
529 565 Whether to wait for the result, or return immediately.
530 566 False:
531 567 returns msg_id(s)
532 568 if multiple targets:
533 569 list of ids
534 570 True:
535 571 returns actual result(s) of f(*args, **kwargs)
536 572 if multiple targets:
537 573 dict of results, by engine ID
538 574 targets : int,list of ints, 'all', None
539 575 Specify the destination of the job.
540 576 if None:
541 577 Submit via Task queue for load-balancing.
542 578 if 'all':
543 579 Run on all active engines
544 580 if list:
545 581 Run on each specified engine
546 582 if int:
547 583 Run on single engine
548 584
549 585 after : Dependency or collection of msg_ids
550 586 Only for load-balanced execution (targets=None)
551 587 Specify a list of msg_ids as a time-based dependency.
552 588 This job will only be run *after* the dependencies
553 589 have been met.
554 590
555 591 follow : Dependency or collection of msg_ids
556 592 Only for load-balanced execution (targets=None)
557 593 Specify a list of msg_ids as a location-based dependency.
558 594 This job will only be run on an engine where this dependency
559 595 is met.
560 596
561 597 Returns
562 598 -------
563 599 if block is False:
564 600 if single target:
565 601 return msg_id
566 602 else:
567 603 return list of msg_ids
568 604 ? (should this be dict like block=True) ?
569 605 else:
570 606 if single target:
571 607 return result of f(*args, **kwargs)
572 608 else:
573 609 return dict of results, keyed by engine
574 610 """
575 611
576 612 # defaults:
577 613 block = block if block is not None else self.block
578 614 args = args if args is not None else []
579 615 kwargs = kwargs if kwargs is not None else {}
580 616
581 617 # enforce types of f,args,kwrags
582 618 if not callable(f):
583 619 raise TypeError("f must be callable, not %s"%type(f))
584 620 if not isinstance(args, (tuple, list)):
585 621 raise TypeError("args must be tuple or list, not %s"%type(args))
586 622 if not isinstance(kwargs, dict):
587 623 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
588 624
589 625 options = dict(bound=bound, block=block, after=after, follow=follow)
590 626
591 627 if targets is None:
592 628 return self._apply_balanced(f, args, kwargs, **options)
593 629 else:
594 630 return self._apply_direct(f, args, kwargs, targets=targets, **options)
595 631
596 632 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
597 633 after=None, follow=None):
598 634 """The underlying method for applying functions in a load balanced
599 635 manner, via the task queue."""
600 636 if isinstance(after, Dependency):
601 637 after = after.as_dict()
602 638 elif after is None:
603 639 after = []
604 640 if isinstance(follow, Dependency):
605 641 follow = follow.as_dict()
606 642 elif follow is None:
607 643 follow = []
608 644 subheader = dict(after=after, follow=follow)
609 645
610 646 bufs = ss.pack_apply_message(f,args,kwargs)
611 647 content = dict(bound=bound)
612 648 msg = self.session.send(self._task_socket, "apply_request",
613 649 content=content, buffers=bufs, subheader=subheader)
614 650 msg_id = msg['msg_id']
615 651 self.outstanding.add(msg_id)
616 652 self.history.append(msg_id)
617 653 if block:
618 654 self.barrier(msg_id)
619 655 return self.results[msg_id]
620 656 else:
621 657 return msg_id
622 658
623 659 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
624 660 after=None, follow=None):
625 661 """Then underlying method for applying functions to specific engines
626 662 via the MUX queue."""
627 663
628 664 queues,targets = self._build_targets(targets)
629 665 bufs = ss.pack_apply_message(f,args,kwargs)
630 666 if isinstance(after, Dependency):
631 667 after = after.as_dict()
632 668 elif after is None:
633 669 after = []
634 670 if isinstance(follow, Dependency):
635 671 follow = follow.as_dict()
636 672 elif follow is None:
637 673 follow = []
638 674 subheader = dict(after=after, follow=follow)
639 675 content = dict(bound=bound)
640 676 msg_ids = []
641 677 for queue in queues:
642 678 msg = self.session.send(self._mux_socket, "apply_request",
643 679 content=content, buffers=bufs,ident=queue, subheader=subheader)
644 680 msg_id = msg['msg_id']
645 681 self.outstanding.add(msg_id)
646 682 self.history.append(msg_id)
647 683 msg_ids.append(msg_id)
648 684 if block:
649 685 self.barrier(msg_ids)
650 686 else:
651 687 if len(msg_ids) == 1:
652 688 return msg_ids[0]
653 689 else:
654 690 return msg_ids
655 691 if len(msg_ids) == 1:
656 692 return self.results[msg_ids[0]]
657 693 else:
658 694 result = {}
659 695 for target,mid in zip(targets, msg_ids):
660 696 result[target] = self.results[mid]
661 697 return result
662 698
663 699 #--------------------------------------------------------------------------
664 700 # Data movement
665 701 #--------------------------------------------------------------------------
666 702
667 703 @defaultblock
668 704 def push(self, ns, targets=None, block=None):
669 705 """Push the contents of `ns` into the namespace on `target`"""
670 706 if not isinstance(ns, dict):
671 707 raise TypeError("Must be a dict, not %s"%type(ns))
672 708 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
673 709 return result
674 710
675 711 @defaultblock
676 712 def pull(self, keys, targets=None, block=True):
677 713 """Pull objects from `target`'s namespace by `keys`"""
678 714 if isinstance(keys, str):
679 715 pass
680 716 elif isistance(keys, (list,tuple,set)):
681 717 for key in keys:
682 718 if not isinstance(key, str):
683 719 raise TypeError
684 720 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
685 721 return result
686 722
687 723 #--------------------------------------------------------------------------
688 724 # Query methods
689 725 #--------------------------------------------------------------------------
690 726
691 727 @spinfirst
692 728 def get_results(self, msg_ids, status_only=False):
693 729 """Returns the result of the execute or task request with `msg_ids`.
694 730
695 731 Parameters
696 732 ----------
697 733 msg_ids : list of ints or msg_ids
698 734 if int:
699 735 Passed as index to self.history for convenience.
700 736 status_only : bool (default: False)
701 737 if False:
702 738 return the actual results
703 739 """
704 740 if not isinstance(msg_ids, (list,tuple)):
705 741 msg_ids = [msg_ids]
706 742 theids = []
707 743 for msg_id in msg_ids:
708 744 if isinstance(msg_id, int):
709 745 msg_id = self.history[msg_id]
710 746 if not isinstance(msg_id, str):
711 747 raise TypeError("msg_ids must be str, not %r"%msg_id)
712 748 theids.append(msg_id)
713 749
714 750 completed = []
715 751 local_results = {}
716 752 for msg_id in list(theids):
717 753 if msg_id in self.results:
718 754 completed.append(msg_id)
719 755 local_results[msg_id] = self.results[msg_id]
720 756 theids.remove(msg_id)
721 757
722 if msg_ids: # some not locally cached
758 if theids: # some not locally cached
723 759 content = dict(msg_ids=theids, status_only=status_only)
724 760 msg = self.session.send(self._query_socket, "result_request", content=content)
725 761 zmq.select([self._query_socket], [], [])
726 762 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
727 763 if self.debug:
728 764 pprint(msg)
729 765 content = msg['content']
730 766 if content['status'] != 'ok':
731 767 raise ss.unwrap_exception(content)
732 768 else:
733 769 content = dict(completed=[],pending=[])
734 770 if not status_only:
735 771 # load cached results into result:
736 772 content['completed'].extend(completed)
737 773 content.update(local_results)
738 774 # update cache with results:
739 775 for msg_id in msg_ids:
740 776 if msg_id in content['completed']:
741 777 self.results[msg_id] = content[msg_id]
742 778 return content
743 779
744 780 @spinfirst
745 781 def queue_status(self, targets=None, verbose=False):
746 782 """Fetch the status of engine queues.
747 783
748 784 Parameters
749 785 ----------
750 786 targets : int/str/list of ints/strs
751 787 the engines on which to execute
752 788 default : all
753 789 verbose : bool
754 790 whether to return lengths only, or lists of ids for each element
755 791 """
756 792 targets = self._build_targets(targets)[1]
757 793 content = dict(targets=targets, verbose=verbose)
758 794 self.session.send(self._query_socket, "queue_request", content=content)
759 795 idents,msg = self.session.recv(self._query_socket, 0)
760 796 if self.debug:
761 797 pprint(msg)
762 798 content = msg['content']
763 799 status = content.pop('status')
764 800 if status != 'ok':
765 801 raise ss.unwrap_exception(content)
766 802 return content
767 803
768 804 @spinfirst
769 805 def purge_results(self, msg_ids=[], targets=[]):
770 806 """Tell the controller to forget results.
771 807
772 808 Individual results can be purged by msg_id, or the entire
773 809 history of specific targets can
774 810
775 811 Parameters
776 812 ----------
777 813 targets : int/str/list of ints/strs
778 814 the targets
779 815 default : None
780 816 """
781 817 if not targets and not msg_ids:
782 818 raise ValueError
783 819 if targets:
784 820 targets = self._build_targets(targets)[1]
785 821 content = dict(targets=targets, msg_ids=msg_ids)
786 822 self.session.send(self._query_socket, "purge_request", content=content)
787 823 idents, msg = self.session.recv(self._query_socket, 0)
788 824 if self.debug:
789 825 pprint(msg)
790 826 content = msg['content']
791 827 if content['status'] != 'ok':
792 828 raise ss.unwrap_exception(content)
793 829
794 830 class AsynClient(Client):
795 831 """An Asynchronous client, using the Tornado Event Loop.
796 832 !!!unfinished!!!"""
797 833 io_loop = None
798 834 _queue_stream = None
799 835 _notifier_stream = None
800 836 _task_stream = None
801 837 _control_stream = None
802 838
803 839 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
804 840 Client.__init__(self, addr, context, username, debug)
805 841 if io_loop is None:
806 842 io_loop = ioloop.IOLoop.instance()
807 843 self.io_loop = io_loop
808 844
809 845 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
810 846 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
811 847 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
812 848 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
813 849
814 850 def spin(self):
815 851 for stream in (self.queue_stream, self.notifier_stream,
816 852 self.task_stream, self.control_stream):
817 853 stream.flush()
818 854 No newline at end of file
@@ -1,200 +1,220
1 1 #!/usr/bin/env python
2 2 """Views"""
3 3
4 4 from IPython.external.decorator import decorator
5 5
6 6
7 7 @decorator
8 8 def myblock(f, self, *args, **kwargs):
9 9 block = self.client.block
10 10 self.client.block = self.block
11 11 ret = f(self, *args, **kwargs)
12 12 self.client.block = block
13 13 return ret
14 14
15 15 @decorator
16 16 def save_ids(f, self, *args, **kwargs):
17 17 ret = f(self, *args, **kwargs)
18 18 msg_ids = self.client.history[-self._ntargets:]
19 19 self.history.extend(msg_ids)
20 20 map(self.outstanding.add, msg_ids)
21 21 return ret
22 22
23 23 @decorator
24 24 def sync_results(f, self, *args, **kwargs):
25 25 ret = f(self, *args, **kwargs)
26 26 delta = self.outstanding.difference(self.client.outstanding)
27 27 completed = self.outstanding.intersection(delta)
28 28 self.outstanding = self.outstanding.difference(completed)
29 29 for msg_id in completed:
30 30 self.results[msg_id] = self.client.results[msg_id]
31 31 return ret
32 32
33 33 @decorator
34 34 def spin_after(f, self, *args, **kwargs):
35 35 ret = f(self, *args, **kwargs)
36 36 self.spin()
37 37 return ret
38 38
39 39
40 40 class View(object):
41 41 """Base View class"""
42 42 _targets = None
43 43 _ntargets = None
44 44 block=None
45 bound=None
45 46 history=None
46 47
47 def __init__(self, client, targets):
48 def __init__(self, client, targets=None):
48 49 self.client = client
49 50 self._targets = targets
50 self._ntargets = 1 if isinstance(targets, int) else len(targets)
51 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
51 52 self.block = client.block
53 self.bound=True
52 54 self.history = []
53 55 self.outstanding = set()
54 56 self.results = {}
55 57
56 58 def __repr__(self):
57 59 strtargets = str(self._targets)
58 60 if len(strtargets) > 16:
59 61 strtargets = strtargets[:12]+'...]'
60 62 return "<%s %s>"%(self.__class__.__name__, strtargets)
61 63
62 64 @property
63 65 def targets(self):
64 66 return self._targets
65 67
66 68 @targets.setter
67 69 def targets(self, value):
68 70 raise TypeError("Cannot set my targets argument after construction!")
69 71
70 72 @sync_results
71 73 def spin(self):
72 74 """spin the client, and sync"""
73 75 self.client.spin()
74 76
75 77 @sync_results
76 78 @save_ids
77 79 def apply(self, f, *args, **kwargs):
78 80 """calls f(*args, **kwargs) on remote engines, returning the result.
79 81
80 82 This method does not involve the engine's namespace.
81 83
82 84 if self.block is False:
83 85 returns msg_id
84 86 else:
85 87 returns actual result of f(*args, **kwargs)
86 88 """
87 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=False)
89 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound)
88 90
89 91 @save_ids
90 92 def apply_async(self, f, *args, **kwargs):
91 93 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
92 94
93 95 This method does not involve the engine's namespace.
94 96
95 97 returns msg_id
96 98 """
97 99 return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False)
98 100
99 101 @spin_after
100 102 @save_ids
101 103 def apply_sync(self, f, *args, **kwargs):
102 104 """calls f(*args, **kwargs) on remote engines in a blocking manner,
103 105 returning the result.
104 106
105 107 This method does not involve the engine's namespace.
106 108
107 109 returns: actual result of f(*args, **kwargs)
108 110 """
109 111 return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False)
110 112
111 113 @sync_results
112 114 @save_ids
113 115 def apply_bound(self, f, *args, **kwargs):
114 116 """calls f(*args, **kwargs) bound to engine namespace(s).
115 117
116 118 if self.block is False:
117 119 returns msg_id
118 120 else:
119 121 returns actual result of f(*args, **kwargs)
120 122
121 123 This method has access to the targets' globals
122 124
123 125 """
124 126 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True)
125 127
126 128 @sync_results
127 129 @save_ids
128 130 def apply_async_bound(self, f, *args, **kwargs):
129 131 """calls f(*args, **kwargs) bound to engine namespace(s)
130 132 in a nonblocking manner.
131 133
132 134 returns: msg_id
133 135
134 136 This method has access to the targets' globals
135 137
136 138 """
137 139 return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True)
138 140
139 141 @spin_after
140 142 @save_ids
141 143 def apply_sync_bound(self, f, *args, **kwargs):
142 144 """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result.
143 145
144 146 returns: actual result of f(*args, **kwargs)
145 147
146 148 This method has access to the targets' globals
147 149
148 150 """
149 151 return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True)
152
153 def abort(self, msg_ids=None, block=None):
154 """Abort jobs on my engines.
155
156 Parameters
157 ----------
158
159 msg_ids : None, str, list of strs, optional
160 if None: abort all jobs.
161 else: abort specific msg_id(s).
162 """
163 block = block if block is not None else self.block
164 return self.client.abort(msg_ids=msg_ids, targets=self.targets, block=block)
165
166 def queue_status(self, verbose=False):
167 """Fetch the Queue status of my engines"""
168 return self.client.queue_status(targets=self.targets, verbose=verbose)
169
170 def purge_results(self, msg_ids=[],targets=[]):
171 """Instruct the controller to forget specific results."""
172 if targets is None or targets == 'all':
173 targets = self.targets
174 return self.client.purge_results(msg_ids=msg_ids, targets=targets)
150 175
151 176
152 177 class DirectView(View):
153 178 """Direct Multiplexer View"""
154 179
155 180 def update(self, ns):
156 181 """update remote namespace with dict `ns`"""
157 182 return self.client.push(ns, targets=self.targets, block=self.block)
158 183
184 push = update
185
159 186 def get(self, key_s):
160 187 """get object(s) by `key_s` from remote namespace
161 188 will return one object if it is a key.
162 189 It also takes a list of keys, and will return a list of objects."""
163 190 # block = block if block is not None else self.block
164 return self.client.pull(key_s, block=self.block, targets=self.targets)
191 return self.client.pull(key_s, block=True, targets=self.targets)
165 192
166 push = update
167 pull = get
193 def pull(self, key_s, block=True):
194 """get object(s) by `key_s` from remote namespace
195 will return one object if it is a key.
196 It also takes a list of keys, and will return a list of objects."""
197 block = block if block is not None else self.block
198 return self.client.pull(key_s, block=block, targets=self.targets)
168 199
169 200 def __getitem__(self, key):
170 201 return self.get(key)
171 202
172 def __setitem__(self,key,value):
203 def __setitem__(self,key, value):
173 204 self.update({key:value})
174 205
175 206 def clear(self, block=False):
176 207 """Clear the remote namespaces on my engines."""
177 208 block = block if block is not None else self.block
178 return self.client.clear(targets=self.targets,block=block)
209 return self.client.clear(targets=self.targets, block=block)
179 210
180 211 def kill(self, block=True):
181 212 """Kill my engines."""
182 213 block = block if block is not None else self.block
183 return self.client.kill(targets=self.targets,block=block)
214 return self.client.kill(targets=self.targets, block=block)
184 215
185 def abort(self, msg_ids=None, block=None):
186 """Abort jobs on my engines.
187
188 Parameters
189 ----------
190
191 msg_ids : None, str, list of strs, optional
192 if None: abort all jobs.
193 else: abort specific msg_id(s).
194 """
195 block = block if block is not None else self.block
196 return self.client.abort(msg_ids=msg_ids, targets=self.targets, block=block)
197
198 216 class LoadBalancedView(View):
199 _targets=None
217 def __repr__(self):
218 return "<%s %s>"%(self.__class__.__name__, self.client._addr)
219
200 220 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now