##// END OF EJS Templates
Fix small bug when no key given to client
Fernando Perez -
Show More
@@ -1,924 +1,924 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import os
16 16 import time
17 17 from pprint import pprint
18 18
19 19 import zmq
20 20 from zmq.eventloop import ioloop, zmqstream
21 21
22 22 from IPython.external.decorator import decorator
23 23 from IPython.zmq import tunnel
24 24
25 25 import streamsession as ss
26 26 # from remotenamespace import RemoteNamespace
27 27 from view import DirectView, LoadBalancedView
28 28 from dependency import Dependency, depend, require
29 29
30 30 def _push(ns):
31 31 globals().update(ns)
32 32
33 33 def _pull(keys):
34 34 g = globals()
35 35 if isinstance(keys, (list,tuple, set)):
36 36 for key in keys:
37 37 if not g.has_key(key):
38 38 raise NameError("name '%s' is not defined"%key)
39 39 return map(g.get, keys)
40 40 else:
41 41 if not g.has_key(keys):
42 42 raise NameError("name '%s' is not defined"%keys)
43 43 return g.get(keys)
44 44
45 45 def _clear():
46 46 globals().clear()
47 47
48 48 def execute(code):
49 49 exec code in globals()
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Decorators for Client methods
53 53 #--------------------------------------------------------------------------
54 54
55 55 @decorator
56 56 def spinfirst(f, self, *args, **kwargs):
57 57 """Call spin() to sync state prior to calling the method."""
58 58 self.spin()
59 59 return f(self, *args, **kwargs)
60 60
61 61 @decorator
62 62 def defaultblock(f, self, *args, **kwargs):
63 63 """Default to self.block; preserve self.block."""
64 64 block = kwargs.get('block',None)
65 65 block = self.block if block is None else block
66 66 saveblock = self.block
67 67 self.block = block
68 68 ret = f(self, *args, **kwargs)
69 69 self.block = saveblock
70 70 return ret
71 71
72 72 def remote(client, bound=False, block=None, targets=None):
73 73 """Turn a function into a remote function.
74 74
75 75 This method can be used for map:
76 76
77 77 >>> @remote(client,block=True)
78 78 def func(a)
79 79 """
80 80 def remote_function(f):
81 81 return RemoteFunction(client, f, bound, block, targets)
82 82 return remote_function
83 83
84 84 #--------------------------------------------------------------------------
85 85 # Classes
86 86 #--------------------------------------------------------------------------
87 87
88 88 class RemoteFunction(object):
89 89 """Turn an existing function into a remote function"""
90 90
91 91 def __init__(self, client, f, bound=False, block=None, targets=None):
92 92 self.client = client
93 93 self.func = f
94 94 self.block=block
95 95 self.bound=bound
96 96 self.targets=targets
97 97
98 98 def __call__(self, *args, **kwargs):
99 99 return self.client.apply(self.func, args=args, kwargs=kwargs,
100 100 block=self.block, targets=self.targets, bound=self.bound)
101 101
102 102
103 103 class AbortedTask(object):
104 104 """A basic wrapper object describing an aborted task."""
105 105 def __init__(self, msg_id):
106 106 self.msg_id = msg_id
107 107
108 108 class ControllerError(Exception):
109 109 def __init__(self, etype, evalue, tb):
110 110 self.etype = etype
111 111 self.evalue = evalue
112 112 self.traceback=tb
113 113
114 114 class Client(object):
115 115 """A semi-synchronous client to the IPython ZMQ controller
116 116
117 117 Parameters
118 118 ----------
119 119
120 120 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
121 121 The address of the controller's registration socket.
122 122 [Default: 'tcp://127.0.0.1:10101']
123 123 context : zmq.Context
124 124 Pass an existing zmq.Context instance, otherwise the client will create its own
125 125 username : bytes
126 126 set username to be passed to the Session object
127 127 debug : bool
128 128 flag for lots of message printing for debug purposes
129 129
130 130 #-------------- ssh related args ----------------
131 131 # These are args for configuring the ssh tunnel to be used
132 132 # credentials are used to forward connections over ssh to the Controller
133 133 # Note that the ip given in `addr` needs to be relative to sshserver
134 134 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
135 135 # and set sshserver as the same machine the Controller is on. However,
136 136 # the only requirement is that sshserver is able to see the Controller
137 137 # (i.e. is within the same trusted network).
138 138
139 139 sshserver : str
140 140 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
141 141 If keyfile or password is specified, and this is not, it will default to
142 142 the ip given in addr.
143 143 sshkey : str; path to public ssh key file
144 144 This specifies a key to be used in ssh login, default None.
145 145 Regular default ssh keys will be used without specifying this argument.
146 146 password : str;
147 147 Your ssh password to sshserver. Note that if this is left None,
148 148 you will be prompted for it if passwordless key based login is unavailable.
149 149
150 150 #------- exec authentication args -------
151 151 # If even localhost is untrusted, you can have some protection against
152 152 # unauthorized execution by using a key. Messages are still sent
153 153 # as cleartext, so if someone can snoop your loopback traffic this will
154 154 # not help anything.
155 155
156 156 exec_key : str
157 157 an authentication key or file containing a key
158 158 default: None
159 159
160 160
161 161 Attributes
162 162 ----------
163 163 ids : set of int engine IDs
164 164 requesting the ids attribute always synchronizes
165 165 the registration state. To request ids without synchronization,
166 166 use semi-private _ids attributes.
167 167
168 168 history : list of msg_ids
169 169 a list of msg_ids, keeping track of all the execution
170 170 messages you have submitted in order.
171 171
172 172 outstanding : set of msg_ids
173 173 a set of msg_ids that have been submitted, but whose
174 174 results have not yet been received.
175 175
176 176 results : dict
177 177 a dict of all our results, keyed by msg_id
178 178
179 179 block : bool
180 180 determines default behavior when block not specified
181 181 in execution methods
182 182
183 183 Methods
184 184 -------
185 185 spin : flushes incoming results and registration state changes
186 186 control methods spin, and requesting `ids` also ensures up to date
187 187
188 188 barrier : wait on one or more msg_ids
189 189
190 190 execution methods: apply/apply_bound/apply_to/apply_bound
191 191 legacy: execute, run
192 192
193 193 query methods: queue_status, get_result, purge
194 194
195 195 control methods: abort, kill
196 196
197 197 """
198 198
199 199
200 200 _connected=False
201 201 _ssh=False
202 202 _engines=None
203 203 _addr='tcp://127.0.0.1:10101'
204 204 _registration_socket=None
205 205 _query_socket=None
206 206 _control_socket=None
207 207 _notification_socket=None
208 208 _mux_socket=None
209 209 _task_socket=None
210 210 block = False
211 211 outstanding=None
212 212 results = None
213 213 history = None
214 214 debug = False
215 215
216 216 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
217 217 sshserver=None, sshkey=None, password=None, paramiko=None,
218 218 exec_key=None,):
219 219 if context is None:
220 220 context = zmq.Context()
221 221 self.context = context
222 222 self._addr = addr
223 223 self._ssh = bool(sshserver or sshkey or password)
224 224 if self._ssh and sshserver is None:
225 225 # default to the same
226 226 sshserver = addr.split('://')[1].split(':')[0]
227 227 if self._ssh and password is None:
228 228 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
229 229 password=False
230 230 else:
231 231 password = getpass("SSH Password for %s: "%sshserver)
232 232 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
233 233
234 if os.path.isfile(exec_key):
234 if exec_key is not None and os.path.isfile(exec_key):
235 235 arg = 'keyfile'
236 236 else:
237 237 arg = 'key'
238 238 key_arg = {arg:exec_key}
239 239 if username is None:
240 240 self.session = ss.StreamSession(**key_arg)
241 241 else:
242 242 self.session = ss.StreamSession(username, **key_arg)
243 243 self._registration_socket = self.context.socket(zmq.XREQ)
244 244 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
245 245 if self._ssh:
246 246 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
247 247 else:
248 248 self._registration_socket.connect(addr)
249 249 self._engines = {}
250 250 self._ids = set()
251 251 self.outstanding=set()
252 252 self.results = {}
253 253 self.history = []
254 254 self.debug = debug
255 255 self.session.debug = debug
256 256
257 257 self._notification_handlers = {'registration_notification' : self._register_engine,
258 258 'unregistration_notification' : self._unregister_engine,
259 259 }
260 260 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
261 261 'apply_reply' : self._handle_apply_reply}
262 262 self._connect(sshserver, ssh_kwargs)
263 263
264 264
265 265 @property
266 266 def ids(self):
267 267 """Always up to date ids property."""
268 268 self._flush_notifications()
269 269 return self._ids
270 270
271 271 def _update_engines(self, engines):
272 272 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
273 273 for k,v in engines.iteritems():
274 274 eid = int(k)
275 275 self._engines[eid] = bytes(v) # force not unicode
276 276 self._ids.add(eid)
277 277
278 278 def _build_targets(self, targets):
279 279 """Turn valid target IDs or 'all' into two lists:
280 280 (int_ids, uuids).
281 281 """
282 282 if targets is None:
283 283 targets = self._ids
284 284 elif isinstance(targets, str):
285 285 if targets.lower() == 'all':
286 286 targets = self._ids
287 287 else:
288 288 raise TypeError("%r not valid str target, must be 'all'"%(targets))
289 289 elif isinstance(targets, int):
290 290 targets = [targets]
291 291 return [self._engines[t] for t in targets], list(targets)
292 292
293 293 def _connect(self, sshserver, ssh_kwargs):
294 294 """setup all our socket connections to the controller. This is called from
295 295 __init__."""
296 296 if self._connected:
297 297 return
298 298 self._connected=True
299 299
300 300 def connect_socket(s, addr):
301 301 if self._ssh:
302 302 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
303 303 else:
304 304 return s.connect(addr)
305 305
306 306 self.session.send(self._registration_socket, 'connection_request')
307 307 idents,msg = self.session.recv(self._registration_socket,mode=0)
308 308 if self.debug:
309 309 pprint(msg)
310 310 msg = ss.Message(msg)
311 311 content = msg.content
312 312 if content.status == 'ok':
313 313 if content.queue:
314 314 self._mux_socket = self.context.socket(zmq.PAIR)
315 315 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
316 316 connect_socket(self._mux_socket, content.queue)
317 317 if content.task:
318 318 self._task_socket = self.context.socket(zmq.PAIR)
319 319 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
320 320 connect_socket(self._task_socket, content.task)
321 321 if content.notification:
322 322 self._notification_socket = self.context.socket(zmq.SUB)
323 323 connect_socket(self._notification_socket, content.notification)
324 324 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
325 325 if content.query:
326 326 self._query_socket = self.context.socket(zmq.PAIR)
327 327 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
328 328 connect_socket(self._query_socket, content.query)
329 329 if content.control:
330 330 self._control_socket = self.context.socket(zmq.PAIR)
331 331 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
332 332 connect_socket(self._control_socket, content.control)
333 333 self._update_engines(dict(content.engines))
334 334
335 335 else:
336 336 self._connected = False
337 337 raise Exception("Failed to connect!")
338 338
339 339 #--------------------------------------------------------------------------
340 340 # handlers and callbacks for incoming messages
341 341 #--------------------------------------------------------------------------
342 342
343 343 def _register_engine(self, msg):
344 344 """Register a new engine, and update our connection info."""
345 345 content = msg['content']
346 346 eid = content['id']
347 347 d = {eid : content['queue']}
348 348 self._update_engines(d)
349 349 self._ids.add(int(eid))
350 350
351 351 def _unregister_engine(self, msg):
352 352 """Unregister an engine that has died."""
353 353 content = msg['content']
354 354 eid = int(content['id'])
355 355 if eid in self._ids:
356 356 self._ids.remove(eid)
357 357 self._engines.pop(eid)
358 358
359 359 def _handle_execute_reply(self, msg):
360 360 """Save the reply to an execute_request into our results."""
361 361 parent = msg['parent_header']
362 362 msg_id = parent['msg_id']
363 363 if msg_id not in self.outstanding:
364 364 print("got unknown result: %s"%msg_id)
365 365 else:
366 366 self.outstanding.remove(msg_id)
367 367 self.results[msg_id] = ss.unwrap_exception(msg['content'])
368 368
369 369 def _handle_apply_reply(self, msg):
370 370 """Save the reply to an apply_request into our results."""
371 371 parent = msg['parent_header']
372 372 msg_id = parent['msg_id']
373 373 if msg_id not in self.outstanding:
374 374 print ("got unknown result: %s"%msg_id)
375 375 else:
376 376 self.outstanding.remove(msg_id)
377 377 content = msg['content']
378 378 if content['status'] == 'ok':
379 379 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
380 380 elif content['status'] == 'aborted':
381 381 self.results[msg_id] = AbortedTask(msg_id)
382 382 elif content['status'] == 'resubmitted':
383 383 # TODO: handle resubmission
384 384 pass
385 385 else:
386 386 self.results[msg_id] = ss.unwrap_exception(content)
387 387
388 388 def _flush_notifications(self):
389 389 """Flush notifications of engine registrations waiting
390 390 in ZMQ queue."""
391 391 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
392 392 while msg is not None:
393 393 if self.debug:
394 394 pprint(msg)
395 395 msg = msg[-1]
396 396 msg_type = msg['msg_type']
397 397 handler = self._notification_handlers.get(msg_type, None)
398 398 if handler is None:
399 399 raise Exception("Unhandled message type: %s"%msg.msg_type)
400 400 else:
401 401 handler(msg)
402 402 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
403 403
404 404 def _flush_results(self, sock):
405 405 """Flush task or queue results waiting in ZMQ queue."""
406 406 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
407 407 while msg is not None:
408 408 if self.debug:
409 409 pprint(msg)
410 410 msg = msg[-1]
411 411 msg_type = msg['msg_type']
412 412 handler = self._queue_handlers.get(msg_type, None)
413 413 if handler is None:
414 414 raise Exception("Unhandled message type: %s"%msg.msg_type)
415 415 else:
416 416 handler(msg)
417 417 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
418 418
419 419 def _flush_control(self, sock):
420 420 """Flush replies from the control channel waiting
421 421 in the ZMQ queue.
422 422
423 423 Currently: ignore them."""
424 424 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
425 425 while msg is not None:
426 426 if self.debug:
427 427 pprint(msg)
428 428 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
429 429
430 430 #--------------------------------------------------------------------------
431 431 # getitem
432 432 #--------------------------------------------------------------------------
433 433
434 434 def __getitem__(self, key):
435 435 """Dict access returns DirectView multiplexer objects or,
436 436 if key is None, a LoadBalancedView."""
437 437 if key is None:
438 438 return LoadBalancedView(self)
439 439 if isinstance(key, int):
440 440 if key not in self.ids:
441 441 raise IndexError("No such engine: %i"%key)
442 442 return DirectView(self, key)
443 443
444 444 if isinstance(key, slice):
445 445 indices = range(len(self.ids))[key]
446 446 ids = sorted(self._ids)
447 447 key = [ ids[i] for i in indices ]
448 448 # newkeys = sorted(self._ids)[thekeys[k]]
449 449
450 450 if isinstance(key, (tuple, list, xrange)):
451 451 _,targets = self._build_targets(list(key))
452 452 return DirectView(self, targets)
453 453 else:
454 454 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
455 455
456 456 #--------------------------------------------------------------------------
457 457 # Begin public methods
458 458 #--------------------------------------------------------------------------
459 459
460 460 def spin(self):
461 461 """Flush any registration notifications and execution results
462 462 waiting in the ZMQ queue.
463 463 """
464 464 if self._notification_socket:
465 465 self._flush_notifications()
466 466 if self._mux_socket:
467 467 self._flush_results(self._mux_socket)
468 468 if self._task_socket:
469 469 self._flush_results(self._task_socket)
470 470 if self._control_socket:
471 471 self._flush_control(self._control_socket)
472 472
473 473 def barrier(self, msg_ids=None, timeout=-1):
474 474 """waits on one or more `msg_ids`, for up to `timeout` seconds.
475 475
476 476 Parameters
477 477 ----------
478 478 msg_ids : int, str, or list of ints and/or strs
479 479 ints are indices to self.history
480 480 strs are msg_ids
481 481 default: wait on all outstanding messages
482 482 timeout : float
483 483 a time in seconds, after which to give up.
484 484 default is -1, which means no timeout
485 485
486 486 Returns
487 487 -------
488 488 True : when all msg_ids are done
489 489 False : timeout reached, some msg_ids still outstanding
490 490 """
491 491 tic = time.time()
492 492 if msg_ids is None:
493 493 theids = self.outstanding
494 494 else:
495 495 if isinstance(msg_ids, (int, str)):
496 496 msg_ids = [msg_ids]
497 497 theids = set()
498 498 for msg_id in msg_ids:
499 499 if isinstance(msg_id, int):
500 500 msg_id = self.history[msg_id]
501 501 theids.add(msg_id)
502 502 self.spin()
503 503 while theids.intersection(self.outstanding):
504 504 if timeout >= 0 and ( time.time()-tic ) > timeout:
505 505 break
506 506 time.sleep(1e-3)
507 507 self.spin()
508 508 return len(theids.intersection(self.outstanding)) == 0
509 509
510 510 #--------------------------------------------------------------------------
511 511 # Control methods
512 512 #--------------------------------------------------------------------------
513 513
514 514 @spinfirst
515 515 @defaultblock
516 516 def clear(self, targets=None, block=None):
517 517 """Clear the namespace in target(s)."""
518 518 targets = self._build_targets(targets)[0]
519 519 for t in targets:
520 520 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
521 521 error = False
522 522 if self.block:
523 523 for i in range(len(targets)):
524 524 idents,msg = self.session.recv(self._control_socket,0)
525 525 if self.debug:
526 526 pprint(msg)
527 527 if msg['content']['status'] != 'ok':
528 528 error = ss.unwrap_exception(msg['content'])
529 529 if error:
530 530 return error
531 531
532 532
533 533 @spinfirst
534 534 @defaultblock
535 535 def abort(self, msg_ids = None, targets=None, block=None):
536 536 """Abort the execution queues of target(s)."""
537 537 targets = self._build_targets(targets)[0]
538 538 if isinstance(msg_ids, basestring):
539 539 msg_ids = [msg_ids]
540 540 content = dict(msg_ids=msg_ids)
541 541 for t in targets:
542 542 self.session.send(self._control_socket, 'abort_request',
543 543 content=content, ident=t)
544 544 error = False
545 545 if self.block:
546 546 for i in range(len(targets)):
547 547 idents,msg = self.session.recv(self._control_socket,0)
548 548 if self.debug:
549 549 pprint(msg)
550 550 if msg['content']['status'] != 'ok':
551 551 error = ss.unwrap_exception(msg['content'])
552 552 if error:
553 553 return error
554 554
555 555 @spinfirst
556 556 @defaultblock
557 557 def shutdown(self, targets=None, restart=False, block=None):
558 558 """Terminates one or more engine processes."""
559 559 targets = self._build_targets(targets)[0]
560 560 for t in targets:
561 561 self.session.send(self._control_socket, 'shutdown_request',
562 562 content={'restart':restart},ident=t)
563 563 error = False
564 564 if self.block:
565 565 for i in range(len(targets)):
566 566 idents,msg = self.session.recv(self._control_socket,0)
567 567 if self.debug:
568 568 pprint(msg)
569 569 if msg['content']['status'] != 'ok':
570 570 error = ss.unwrap_exception(msg['content'])
571 571 if error:
572 572 return error
573 573
574 574 #--------------------------------------------------------------------------
575 575 # Execution methods
576 576 #--------------------------------------------------------------------------
577 577
578 578 @defaultblock
579 579 def execute(self, code, targets='all', block=None):
580 580 """Executes `code` on `targets` in blocking or nonblocking manner.
581 581
582 582 Parameters
583 583 ----------
584 584 code : str
585 585 the code string to be executed
586 586 targets : int/str/list of ints/strs
587 587 the engines on which to execute
588 588 default : all
589 589 block : bool
590 590 whether or not to wait until done to return
591 591 default: self.block
592 592 """
593 593 # block = self.block if block is None else block
594 594 # saveblock = self.block
595 595 # self.block = block
596 596 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
597 597 # self.block = saveblock
598 598 return result
599 599
600 600 def run(self, code, block=None):
601 601 """Runs `code` on an engine.
602 602
603 603 Calls to this are load-balanced.
604 604
605 605 Parameters
606 606 ----------
607 607 code : str
608 608 the code string to be executed
609 609 block : bool
610 610 whether or not to wait until done
611 611
612 612 """
613 613 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
614 614 return result
615 615
616 616 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
617 617 after=None, follow=None):
618 618 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
619 619
620 620 This is the central execution command for the client.
621 621
622 622 Parameters
623 623 ----------
624 624
625 625 f : function
626 626 The fuction to be called remotely
627 627 args : tuple/list
628 628 The positional arguments passed to `f`
629 629 kwargs : dict
630 630 The keyword arguments passed to `f`
631 631 bound : bool (default: True)
632 632 Whether to execute in the Engine(s) namespace, or in a clean
633 633 namespace not affecting the engine.
634 634 block : bool (default: self.block)
635 635 Whether to wait for the result, or return immediately.
636 636 False:
637 637 returns msg_id(s)
638 638 if multiple targets:
639 639 list of ids
640 640 True:
641 641 returns actual result(s) of f(*args, **kwargs)
642 642 if multiple targets:
643 643 dict of results, by engine ID
644 644 targets : int,list of ints, 'all', None
645 645 Specify the destination of the job.
646 646 if None:
647 647 Submit via Task queue for load-balancing.
648 648 if 'all':
649 649 Run on all active engines
650 650 if list:
651 651 Run on each specified engine
652 652 if int:
653 653 Run on single engine
654 654
655 655 after : Dependency or collection of msg_ids
656 656 Only for load-balanced execution (targets=None)
657 657 Specify a list of msg_ids as a time-based dependency.
658 658 This job will only be run *after* the dependencies
659 659 have been met.
660 660
661 661 follow : Dependency or collection of msg_ids
662 662 Only for load-balanced execution (targets=None)
663 663 Specify a list of msg_ids as a location-based dependency.
664 664 This job will only be run on an engine where this dependency
665 665 is met.
666 666
667 667 Returns
668 668 -------
669 669 if block is False:
670 670 if single target:
671 671 return msg_id
672 672 else:
673 673 return list of msg_ids
674 674 ? (should this be dict like block=True) ?
675 675 else:
676 676 if single target:
677 677 return result of f(*args, **kwargs)
678 678 else:
679 679 return dict of results, keyed by engine
680 680 """
681 681
682 682 # defaults:
683 683 block = block if block is not None else self.block
684 684 args = args if args is not None else []
685 685 kwargs = kwargs if kwargs is not None else {}
686 686
687 687 # enforce types of f,args,kwrags
688 688 if not callable(f):
689 689 raise TypeError("f must be callable, not %s"%type(f))
690 690 if not isinstance(args, (tuple, list)):
691 691 raise TypeError("args must be tuple or list, not %s"%type(args))
692 692 if not isinstance(kwargs, dict):
693 693 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
694 694
695 695 options = dict(bound=bound, block=block, after=after, follow=follow)
696 696
697 697 if targets is None:
698 698 return self._apply_balanced(f, args, kwargs, **options)
699 699 else:
700 700 return self._apply_direct(f, args, kwargs, targets=targets, **options)
701 701
702 702 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
703 703 after=None, follow=None):
704 704 """The underlying method for applying functions in a load balanced
705 705 manner, via the task queue."""
706 706 if isinstance(after, Dependency):
707 707 after = after.as_dict()
708 708 elif after is None:
709 709 after = []
710 710 if isinstance(follow, Dependency):
711 711 follow = follow.as_dict()
712 712 elif follow is None:
713 713 follow = []
714 714 subheader = dict(after=after, follow=follow)
715 715
716 716 bufs = ss.pack_apply_message(f,args,kwargs)
717 717 content = dict(bound=bound)
718 718 msg = self.session.send(self._task_socket, "apply_request",
719 719 content=content, buffers=bufs, subheader=subheader)
720 720 msg_id = msg['msg_id']
721 721 self.outstanding.add(msg_id)
722 722 self.history.append(msg_id)
723 723 if block:
724 724 self.barrier(msg_id)
725 725 return self.results[msg_id]
726 726 else:
727 727 return msg_id
728 728
729 729 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
730 730 after=None, follow=None):
731 731 """Then underlying method for applying functions to specific engines
732 732 via the MUX queue."""
733 733
734 734 queues,targets = self._build_targets(targets)
735 735 bufs = ss.pack_apply_message(f,args,kwargs)
736 736 if isinstance(after, Dependency):
737 737 after = after.as_dict()
738 738 elif after is None:
739 739 after = []
740 740 if isinstance(follow, Dependency):
741 741 follow = follow.as_dict()
742 742 elif follow is None:
743 743 follow = []
744 744 subheader = dict(after=after, follow=follow)
745 745 content = dict(bound=bound)
746 746 msg_ids = []
747 747 for queue in queues:
748 748 msg = self.session.send(self._mux_socket, "apply_request",
749 749 content=content, buffers=bufs,ident=queue, subheader=subheader)
750 750 msg_id = msg['msg_id']
751 751 self.outstanding.add(msg_id)
752 752 self.history.append(msg_id)
753 753 msg_ids.append(msg_id)
754 754 if block:
755 755 self.barrier(msg_ids)
756 756 else:
757 757 if len(msg_ids) == 1:
758 758 return msg_ids[0]
759 759 else:
760 760 return msg_ids
761 761 if len(msg_ids) == 1:
762 762 return self.results[msg_ids[0]]
763 763 else:
764 764 result = {}
765 765 for target,mid in zip(targets, msg_ids):
766 766 result[target] = self.results[mid]
767 767 return result
768 768
769 769 #--------------------------------------------------------------------------
770 770 # Data movement
771 771 #--------------------------------------------------------------------------
772 772
773 773 @defaultblock
774 774 def push(self, ns, targets=None, block=None):
775 775 """Push the contents of `ns` into the namespace on `target`"""
776 776 if not isinstance(ns, dict):
777 777 raise TypeError("Must be a dict, not %s"%type(ns))
778 778 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
779 779 return result
780 780
781 781 @defaultblock
782 782 def pull(self, keys, targets=None, block=True):
783 783 """Pull objects from `target`'s namespace by `keys`"""
784 784 if isinstance(keys, str):
785 785 pass
786 786 elif isistance(keys, (list,tuple,set)):
787 787 for key in keys:
788 788 if not isinstance(key, str):
789 789 raise TypeError
790 790 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
791 791 return result
792 792
793 793 #--------------------------------------------------------------------------
794 794 # Query methods
795 795 #--------------------------------------------------------------------------
796 796
797 797 @spinfirst
798 798 def get_results(self, msg_ids, status_only=False):
799 799 """Returns the result of the execute or task request with `msg_ids`.
800 800
801 801 Parameters
802 802 ----------
803 803 msg_ids : list of ints or msg_ids
804 804 if int:
805 805 Passed as index to self.history for convenience.
806 806 status_only : bool (default: False)
807 807 if False:
808 808 return the actual results
809 809 """
810 810 if not isinstance(msg_ids, (list,tuple)):
811 811 msg_ids = [msg_ids]
812 812 theids = []
813 813 for msg_id in msg_ids:
814 814 if isinstance(msg_id, int):
815 815 msg_id = self.history[msg_id]
816 816 if not isinstance(msg_id, str):
817 817 raise TypeError("msg_ids must be str, not %r"%msg_id)
818 818 theids.append(msg_id)
819 819
820 820 completed = []
821 821 local_results = {}
822 822 for msg_id in list(theids):
823 823 if msg_id in self.results:
824 824 completed.append(msg_id)
825 825 local_results[msg_id] = self.results[msg_id]
826 826 theids.remove(msg_id)
827 827
828 828 if theids: # some not locally cached
829 829 content = dict(msg_ids=theids, status_only=status_only)
830 830 msg = self.session.send(self._query_socket, "result_request", content=content)
831 831 zmq.select([self._query_socket], [], [])
832 832 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
833 833 if self.debug:
834 834 pprint(msg)
835 835 content = msg['content']
836 836 if content['status'] != 'ok':
837 837 raise ss.unwrap_exception(content)
838 838 else:
839 839 content = dict(completed=[],pending=[])
840 840 if not status_only:
841 841 # load cached results into result:
842 842 content['completed'].extend(completed)
843 843 content.update(local_results)
844 844 # update cache with results:
845 845 for msg_id in msg_ids:
846 846 if msg_id in content['completed']:
847 847 self.results[msg_id] = content[msg_id]
848 848 return content
849 849
850 850 @spinfirst
851 851 def queue_status(self, targets=None, verbose=False):
852 852 """Fetch the status of engine queues.
853 853
854 854 Parameters
855 855 ----------
856 856 targets : int/str/list of ints/strs
857 857 the engines on which to execute
858 858 default : all
859 859 verbose : bool
860 860 whether to return lengths only, or lists of ids for each element
861 861 """
862 862 targets = self._build_targets(targets)[1]
863 863 content = dict(targets=targets, verbose=verbose)
864 864 self.session.send(self._query_socket, "queue_request", content=content)
865 865 idents,msg = self.session.recv(self._query_socket, 0)
866 866 if self.debug:
867 867 pprint(msg)
868 868 content = msg['content']
869 869 status = content.pop('status')
870 870 if status != 'ok':
871 871 raise ss.unwrap_exception(content)
872 872 return content
873 873
874 874 @spinfirst
875 875 def purge_results(self, msg_ids=[], targets=[]):
876 876 """Tell the controller to forget results.
877 877
878 878 Individual results can be purged by msg_id, or the entire
879 879 history of specific targets can
880 880
881 881 Parameters
882 882 ----------
883 883 targets : int/str/list of ints/strs
884 884 the targets
885 885 default : None
886 886 """
887 887 if not targets and not msg_ids:
888 888 raise ValueError
889 889 if targets:
890 890 targets = self._build_targets(targets)[1]
891 891 content = dict(targets=targets, msg_ids=msg_ids)
892 892 self.session.send(self._query_socket, "purge_request", content=content)
893 893 idents, msg = self.session.recv(self._query_socket, 0)
894 894 if self.debug:
895 895 pprint(msg)
896 896 content = msg['content']
897 897 if content['status'] != 'ok':
898 898 raise ss.unwrap_exception(content)
899 899
900 900 class AsynClient(Client):
901 901 """An Asynchronous client, using the Tornado Event Loop.
902 902 !!!unfinished!!!"""
903 903 io_loop = None
904 904 _queue_stream = None
905 905 _notifier_stream = None
906 906 _task_stream = None
907 907 _control_stream = None
908 908
909 909 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
910 910 Client.__init__(self, addr, context, username, debug)
911 911 if io_loop is None:
912 912 io_loop = ioloop.IOLoop.instance()
913 913 self.io_loop = io_loop
914 914
915 915 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
916 916 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
917 917 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
918 918 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
919 919
920 920 def spin(self):
921 921 for stream in (self.queue_stream, self.notifier_stream,
922 922 self.task_stream, self.control_stream):
923 923 stream.flush()
924 924
General Comments 0
You need to be logged in to leave comments. Login now