##// END OF EJS Templates
some docstring cleanup
MinRK -
Show More
@@ -1,975 +1,988 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import os
16 16 import time
17 17 from getpass import getpass
18 18 from pprint import pprint
19 19
20 20 import zmq
21 21 from zmq.eventloop import ioloop, zmqstream
22 22
23 23 from IPython.external.decorator import decorator
24 24 from IPython.zmq import tunnel
25 25
26 26 import streamsession as ss
27 27 # from remotenamespace import RemoteNamespace
28 28 from view import DirectView, LoadBalancedView
29 29 from dependency import Dependency, depend, require
30 30 import error
31 31
32 #--------------------------------------------------------------------------
33 # helpers for implementing old MEC API via client.apply
34 #--------------------------------------------------------------------------
35
32 36 def _push(ns):
37 """helper method for implementing `client.push` via `client.apply`"""
33 38 globals().update(ns)
34 39
35 40 def _pull(keys):
41 """helper method for implementing `client.pull` via `client.apply`"""
36 42 g = globals()
37 43 if isinstance(keys, (list,tuple, set)):
38 44 for key in keys:
39 45 if not g.has_key(key):
40 46 raise NameError("name '%s' is not defined"%key)
41 47 return map(g.get, keys)
42 48 else:
43 49 if not g.has_key(keys):
44 50 raise NameError("name '%s' is not defined"%keys)
45 51 return g.get(keys)
46 52
def _clear():
    """helper method for implementing `client.clear` via `client.apply`"""
    # wipes the engine's entire global namespace, including imports
    globals().clear()
49 56
def execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    # exec(code, globals()) is accepted by both Python 2 and Python 3;
    # the old statement form ``exec code in globals()`` is py2-only syntax
    exec(code, globals())
60
52 61
53 62 #--------------------------------------------------------------------------
54 63 # Decorators for Client methods
55 64 #--------------------------------------------------------------------------
56 65
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # drain pending notifications/results so the method sees current state
    self.spin()
    return f(self, *args, **kwargs)
62 71
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block.

    If the wrapped method is called without an explicit `block` kwarg,
    the client's current `block` attribute is used; `self.block` is
    always restored afterwards.
    """
    block = kwargs.get('block', None)
    block = self.block if block is None else block
    saveblock = self.block
    self.block = block
    try:
        return f(self, *args, **kwargs)
    finally:
        # restore even when f raises, so a failed call cannot leave the
        # client stuck with the temporary block setting
        self.block = saveblock
73 82
def remote(client, bound=False, block=None, targets=None):
    """Decorator factory: turn a function into a remote function.

    This method can be used for map:

    >>> @remote(client,block=True)
    def func(a)
    """
    def _wrap(f):
        # all the real work happens in RemoteFunction / client.apply
        return RemoteFunction(client, f, bound, block, targets)
    return _wrap
85 94
86 95 #--------------------------------------------------------------------------
87 96 # Classes
88 97 #--------------------------------------------------------------------------
89 98
class RemoteFunction(object):
    """Wrap a local function so that calls execute remotely via a Client.

    Parameters
    ----------

    client : Client instance
        The client used to reach the engines
    f : callable
        The function to be wrapped into a remote function
    bound : bool [default: False]
        Whether calls affect the remote (engine) namespace
    block : bool [default: None]
        Whether to wait for results; None defers to the current
        `block` attribute of `client`
    targets : valid target list [default: all]
        The targets on which to execute.
    """

    client = None   # the remote connection
    func = None     # the wrapped function
    block = None    # whether to block
    bound = None    # whether to affect the namespace
    targets = None  # where to execute

    def __init__(self, client, f, bound=False, block=None, targets=None):
        self.client = client
        self.func = f
        self.block = block
        self.bound = bound
        self.targets = targets

    def __call__(self, *args, **kwargs):
        # delegate everything to the client's central apply() entry point
        return self.client.apply(self.func, args=args, kwargs=kwargs,
                                 block=self.block, targets=self.targets,
                                 bound=self.bound)
125 134
126 135
class AbortedTask(object):
    """A basic wrapper object describing an aborted task."""

    def __init__(self, msg_id):
        # only the id of the aborted message is recorded
        self.msg_id = msg_id
131 140
class ResultDict(dict):
    """A dict subclass whose lookups re-raise stored KernelErrors."""

    def __getitem__(self, key):
        value = dict.__getitem__(self, key)
        # stored remote failures propagate to the caller on access
        if isinstance(value, error.KernelError):
            raise value
        return value
139 148
140 149 class Client(object):
141 150 """A semi-synchronous client to the IPython ZMQ controller
142 151
143 152 Parameters
144 153 ----------
145 154
146 155 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
147 156 The address of the controller's registration socket.
148 157 [Default: 'tcp://127.0.0.1:10101']
149 158 context : zmq.Context
150 159 Pass an existing zmq.Context instance, otherwise the client will create its own
151 160 username : bytes
152 161 set username to be passed to the Session object
153 162 debug : bool
154 163 flag for lots of message printing for debug purposes
155 164
156 165 #-------------- ssh related args ----------------
157 166 # These are args for configuring the ssh tunnel to be used
158 167 # credentials are used to forward connections over ssh to the Controller
159 168 # Note that the ip given in `addr` needs to be relative to sshserver
160 169 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
161 170 # and set sshserver as the same machine the Controller is on. However,
162 171 # the only requirement is that sshserver is able to see the Controller
163 172 # (i.e. is within the same trusted network).
164 173
165 174 sshserver : str
166 175 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
167 176 If keyfile or password is specified, and this is not, it will default to
168 177 the ip given in addr.
169 178 sshkey : str; path to public ssh key file
170 179 This specifies a key to be used in ssh login, default None.
171 180 Regular default ssh keys will be used without specifying this argument.
172 181 password : str;
173 182 Your ssh password to sshserver. Note that if this is left None,
174 183 you will be prompted for it if passwordless key based login is unavailable.
175 184
176 185 #------- exec authentication args -------
177 186 # If even localhost is untrusted, you can have some protection against
178 187 # unauthorized execution by using a key. Messages are still sent
179 188 # as cleartext, so if someone can snoop your loopback traffic this will
180 189 # not help anything.
181 190
182 191 exec_key : str
183 192 an authentication key or file containing a key
184 193 default: None
185 194
186 195
187 196 Attributes
188 197 ----------
189 198 ids : set of int engine IDs
190 199 requesting the ids attribute always synchronizes
191 200 the registration state. To request ids without synchronization,
192 201 use semi-private _ids attributes.
193 202
194 203 history : list of msg_ids
195 204 a list of msg_ids, keeping track of all the execution
196 205 messages you have submitted in order.
197 206
198 207 outstanding : set of msg_ids
199 208 a set of msg_ids that have been submitted, but whose
200 209 results have not yet been received.
201 210
202 211 results : dict
203 212 a dict of all our results, keyed by msg_id
204 213
205 214 block : bool
206 215 determines default behavior when block not specified
207 216 in execution methods
208 217
209 218 Methods
210 219 -------
211 220 spin : flushes incoming results and registration state changes
212 221 control methods spin, and requesting `ids` also ensures up to date
213 222
214 223 barrier : wait on one or more msg_ids
215 224
216 225 execution methods: apply/apply_bound/apply_to/apply_bound
217 226 legacy: execute, run
218 227
219 228 query methods: queue_status, get_result, purge
220 229
221 230 control methods: abort, kill
222 231
223 232 """
224 233
225 234
226 235 _connected=False
227 236 _ssh=False
228 237 _engines=None
229 238 _addr='tcp://127.0.0.1:10101'
230 239 _registration_socket=None
231 240 _query_socket=None
232 241 _control_socket=None
233 242 _notification_socket=None
234 243 _mux_socket=None
235 244 _task_socket=None
236 245 block = False
237 246 outstanding=None
238 247 results = None
239 248 history = None
240 249 debug = False
241 250
    def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            exec_key=None,):
        # reuse a caller-supplied zmq context, or create and own a fresh one
        if context is None:
            context = zmq.Context()
        self.context = context
        self._addr = addr
        # any ssh-related argument switches on tunnelled connections
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default the ssh server to the host part of `addr`
            sshserver = addr.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            # only prompt when passwordless login is unavailable
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # exec_key may be a path to a keyfile, or the key material itself
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._registration_socket = self.context.socket(zmq.XREQ)
        self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
        else:
            self._registration_socket.connect(addr)
        # per-instance bookkeeping; the class-level attributes are placeholders
        self._engines = {}
        self._ids = set()
        self.outstanding=set()
        self.results = {}
        self.history = []
        self.debug = debug
        self.session.debug = debug

        # dispatch tables used by _flush_notifications / _flush_results
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        # establish the remaining socket connections to the controller
        self._connect(sshserver, ssh_kwargs)
290 299
291 300 @property
292 301 def ids(self):
293 302 """Always up to date ids property."""
294 303 self._flush_notifications()
295 304 return self._ids
296 305
297 306 def _update_engines(self, engines):
298 307 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
299 308 for k,v in engines.iteritems():
300 309 eid = int(k)
301 310 self._engines[eid] = bytes(v) # force not unicode
302 311 self._ids.add(eid)
303 312
304 313 def _build_targets(self, targets):
305 314 """Turn valid target IDs or 'all' into two lists:
306 315 (int_ids, uuids).
307 316 """
308 317 if targets is None:
309 318 targets = self._ids
310 319 elif isinstance(targets, str):
311 320 if targets.lower() == 'all':
312 321 targets = self._ids
313 322 else:
314 323 raise TypeError("%r not valid str target, must be 'all'"%(targets))
315 324 elif isinstance(targets, int):
316 325 targets = [targets]
317 326 return [self._engines[t] for t in targets], list(targets)
318 327
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""
        # idempotent: only the first call does anything
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, addr):
            # route through the ssh tunnel when one is configured
            if self._ssh:
                return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
            else:
                return s.connect(addr)

        self.session.send(self._registration_socket, 'connection_request')
        idents,msg = self.session.recv(self._registration_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        if content.status == 'ok':
            # the reply advertises one address per service; connect a socket
            # for each service the controller actually offers
            if content.queue:
                self._mux_socket = self.context.socket(zmq.PAIR)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.queue)
            if content.task:
                self._task_socket = self.context.socket(zmq.PAIR)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, content.task)
            if content.notification:
                self._notification_socket = self.context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # empty prefix subscribes to all notifications
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
            if content.query:
                self._query_socket = self.context.socket(zmq.PAIR)
                self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self.context.socket(zmq.PAIR)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            # seed the engine registry with the currently registered engines
            self._update_engines(dict(content.engines))

        else:
            self._connected = False
            raise Exception("Failed to connect!")
364 373
365 374 #--------------------------------------------------------------------------
366 375 # handlers and callbacks for incoming messages
367 376 #--------------------------------------------------------------------------
368 377
369 378 def _register_engine(self, msg):
370 379 """Register a new engine, and update our connection info."""
371 380 content = msg['content']
372 381 eid = content['id']
373 382 d = {eid : content['queue']}
374 383 self._update_engines(d)
375 384 self._ids.add(int(eid))
376 385
377 386 def _unregister_engine(self, msg):
378 387 """Unregister an engine that has died."""
379 388 content = msg['content']
380 389 eid = int(content['id'])
381 390 if eid in self._ids:
382 391 self._ids.remove(eid)
383 392 self._engines.pop(eid)
384 393
385 394 def _handle_execute_reply(self, msg):
386 395 """Save the reply to an execute_request into our results."""
387 396 parent = msg['parent_header']
388 397 msg_id = parent['msg_id']
389 398 if msg_id not in self.outstanding:
390 399 print("got unknown result: %s"%msg_id)
391 400 else:
392 401 self.outstanding.remove(msg_id)
393 402 self.results[msg_id] = ss.unwrap_exception(msg['content'])
394 403
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            content = msg['content']
            if content['status'] == 'ok':
                # successful results travel in the message buffers
                self.results[msg_id] = ss.unserialize_object(msg['buffers'])
            elif content['status'] == 'aborted':
                # NOTE(review): this references error.AbortedTask, but this
                # module defines its own AbortedTask class above — confirm
                # the `error` module actually provides AbortedTask
                self.results[msg_id] = error.AbortedTask(msg_id)
            elif content['status'] == 'resubmitted':
                # TODO: handle resubmission
                pass
            else:
                # remote failure: map the engine uuid back to its int id
                # before storing the exception
                e = ss.unwrap_exception(content)
                e_uuid = e.engine_info['engineid']
                for k,v in self._engines.iteritems():
                    if v == e_uuid:
                        e.engine_info['engineid'] = k
                        break
                self.results[msg_id] = e
419 428
420 429 def _flush_notifications(self):
421 430 """Flush notifications of engine registrations waiting
422 431 in ZMQ queue."""
423 432 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
424 433 while msg is not None:
425 434 if self.debug:
426 435 pprint(msg)
427 436 msg = msg[-1]
428 437 msg_type = msg['msg_type']
429 438 handler = self._notification_handlers.get(msg_type, None)
430 439 if handler is None:
431 440 raise Exception("Unhandled message type: %s"%msg.msg_type)
432 441 else:
433 442 handler(msg)
434 443 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
435 444
436 445 def _flush_results(self, sock):
437 446 """Flush task or queue results waiting in ZMQ queue."""
438 447 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
439 448 while msg is not None:
440 449 if self.debug:
441 450 pprint(msg)
442 451 msg = msg[-1]
443 452 msg_type = msg['msg_type']
444 453 handler = self._queue_handlers.get(msg_type, None)
445 454 if handler is None:
446 455 raise Exception("Unhandled message type: %s"%msg.msg_type)
447 456 else:
448 457 handler(msg)
449 458 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
450 459
451 460 def _flush_control(self, sock):
452 461 """Flush replies from the control channel waiting
453 462 in the ZMQ queue.
454 463
455 464 Currently: ignore them."""
456 465 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
457 466 while msg is not None:
458 467 if self.debug:
459 468 pprint(msg)
460 469 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
461 470
462 471 #--------------------------------------------------------------------------
463 472 # getitem
464 473 #--------------------------------------------------------------------------
465 474
466 475 def __getitem__(self, key):
467 476 """Dict access returns DirectView multiplexer objects or,
468 477 if key is None, a LoadBalancedView."""
469 478 if key is None:
470 479 return LoadBalancedView(self)
471 480 if isinstance(key, int):
472 481 if key not in self.ids:
473 482 raise IndexError("No such engine: %i"%key)
474 483 return DirectView(self, key)
475 484
476 485 if isinstance(key, slice):
477 486 indices = range(len(self.ids))[key]
478 487 ids = sorted(self._ids)
479 488 key = [ ids[i] for i in indices ]
480 489 # newkeys = sorted(self._ids)[thekeys[k]]
481 490
482 491 if isinstance(key, (tuple, list, xrange)):
483 492 _,targets = self._build_targets(list(key))
484 493 return DirectView(self, targets)
485 494 else:
486 495 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
487 496
488 497 #--------------------------------------------------------------------------
489 498 # Begin public methods
490 499 #--------------------------------------------------------------------------
491 500
492 501 def spin(self):
493 502 """Flush any registration notifications and execution results
494 503 waiting in the ZMQ queue.
495 504 """
496 505 if self._notification_socket:
497 506 self._flush_notifications()
498 507 if self._mux_socket:
499 508 self._flush_results(self._mux_socket)
500 509 if self._task_socket:
501 510 self._flush_results(self._task_socket)
502 511 if self._control_socket:
503 512 self._flush_control(self._control_socket)
504 513
505 514 def barrier(self, msg_ids=None, timeout=-1):
506 515 """waits on one or more `msg_ids`, for up to `timeout` seconds.
507 516
508 517 Parameters
509 518 ----------
510 519 msg_ids : int, str, or list of ints and/or strs
511 520 ints are indices to self.history
512 521 strs are msg_ids
513 522 default: wait on all outstanding messages
514 523 timeout : float
515 524 a time in seconds, after which to give up.
516 525 default is -1, which means no timeout
517 526
518 527 Returns
519 528 -------
520 529 True : when all msg_ids are done
521 530 False : timeout reached, some msg_ids still outstanding
522 531 """
523 532 tic = time.time()
524 533 if msg_ids is None:
525 534 theids = self.outstanding
526 535 else:
527 536 if isinstance(msg_ids, (int, str)):
528 537 msg_ids = [msg_ids]
529 538 theids = set()
530 539 for msg_id in msg_ids:
531 540 if isinstance(msg_id, int):
532 541 msg_id = self.history[msg_id]
533 542 theids.add(msg_id)
534 543 self.spin()
535 544 while theids.intersection(self.outstanding):
536 545 if timeout >= 0 and ( time.time()-tic ) > timeout:
537 546 break
538 547 time.sleep(1e-3)
539 548 self.spin()
540 549 return len(theids.intersection(self.outstanding)) == 0
541 550
542 551 #--------------------------------------------------------------------------
543 552 # Control methods
544 553 #--------------------------------------------------------------------------
545 554
546 555 @spinfirst
547 556 @defaultblock
548 557 def clear(self, targets=None, block=None):
549 558 """Clear the namespace in target(s)."""
550 559 targets = self._build_targets(targets)[0]
551 560 for t in targets:
552 561 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
553 562 error = False
554 563 if self.block:
555 564 for i in range(len(targets)):
556 565 idents,msg = self.session.recv(self._control_socket,0)
557 566 if self.debug:
558 567 pprint(msg)
559 568 if msg['content']['status'] != 'ok':
560 569 error = ss.unwrap_exception(msg['content'])
561 570 if error:
562 571 return error
563 572
564 573
565 574 @spinfirst
566 575 @defaultblock
567 576 def abort(self, msg_ids = None, targets=None, block=None):
568 577 """Abort the execution queues of target(s)."""
569 578 targets = self._build_targets(targets)[0]
570 579 if isinstance(msg_ids, basestring):
571 580 msg_ids = [msg_ids]
572 581 content = dict(msg_ids=msg_ids)
573 582 for t in targets:
574 583 self.session.send(self._control_socket, 'abort_request',
575 584 content=content, ident=t)
576 585 error = False
577 586 if self.block:
578 587 for i in range(len(targets)):
579 588 idents,msg = self.session.recv(self._control_socket,0)
580 589 if self.debug:
581 590 pprint(msg)
582 591 if msg['content']['status'] != 'ok':
583 592 error = ss.unwrap_exception(msg['content'])
584 593 if error:
585 594 return error
586 595
    @spinfirst
    @defaultblock
    def shutdown(self, targets=None, restart=False, controller=False, block=None):
        """Terminates one or more engine processes, optionally including the controller."""
        if controller:
            # shutting down the controller implies shutting down all engines
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        # NOTE(review): this tests the raw `block` argument (possibly None)
        # rather than self.block as the other control methods do — confirm
        # whether that asymmetry is intentional
        if block or controller:
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = ss.unwrap_exception(msg['content'])

        if controller:
            # give engines a moment to go down before the controller itself
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = ss.unwrap_exception(msg['content'])

        if error:
            # unlike clear/abort (which return the error), shutdown raises it
            raise error
617 626
618 627 #--------------------------------------------------------------------------
619 628 # Execution methods
620 629 #--------------------------------------------------------------------------
621 630
622 631 @defaultblock
623 632 def execute(self, code, targets='all', block=None):
624 633 """Executes `code` on `targets` in blocking or nonblocking manner.
625 634
635 ``execute`` is always `bound` (affects engine namespace)
636
626 637 Parameters
627 638 ----------
628 639 code : str
629 640 the code string to be executed
630 641 targets : int/str/list of ints/strs
631 642 the engines on which to execute
632 643 default : all
633 644 block : bool
634 645 whether or not to wait until done to return
635 646 default: self.block
636 647 """
637 # block = self.block if block is None else block
638 # saveblock = self.block
639 # self.block = block
640 648 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
641 # self.block = saveblock
642 649 return result
643 650
644 651 def run(self, code, block=None):
645 652 """Runs `code` on an engine.
646 653
647 654 Calls to this are load-balanced.
648 655
656 ``run`` is never `bound` (no effect on engine namespace)
657
649 658 Parameters
650 659 ----------
651 660 code : str
652 661 the code string to be executed
653 662 block : bool
654 663 whether or not to wait until done
655 664
656 665 """
657 666 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
658 667 return result
659 668
660 669 def _maybe_raise(self, result):
661 670 """wrapper for maybe raising an exception if apply failed."""
662 671 if isinstance(result, error.RemoteError):
663 672 raise result
664 673
665 674 return result
666 675
667 676 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
668 677 after=None, follow=None):
669 678 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
670 679
671 680 This is the central execution command for the client.
672 681
673 682 Parameters
674 683 ----------
675 684
676 685 f : function
677 686 The fuction to be called remotely
678 687 args : tuple/list
679 688 The positional arguments passed to `f`
680 689 kwargs : dict
681 690 The keyword arguments passed to `f`
682 691 bound : bool (default: True)
683 692 Whether to execute in the Engine(s) namespace, or in a clean
684 693 namespace not affecting the engine.
685 694 block : bool (default: self.block)
686 695 Whether to wait for the result, or return immediately.
687 696 False:
688 697 returns msg_id(s)
689 698 if multiple targets:
690 699 list of ids
691 700 True:
692 701 returns actual result(s) of f(*args, **kwargs)
693 702 if multiple targets:
694 703 dict of results, by engine ID
695 704 targets : int,list of ints, 'all', None
696 705 Specify the destination of the job.
697 706 if None:
698 707 Submit via Task queue for load-balancing.
699 708 if 'all':
700 709 Run on all active engines
701 710 if list:
702 711 Run on each specified engine
703 712 if int:
704 713 Run on single engine
705 714
706 715 after : Dependency or collection of msg_ids
707 716 Only for load-balanced execution (targets=None)
708 717 Specify a list of msg_ids as a time-based dependency.
709 718 This job will only be run *after* the dependencies
710 719 have been met.
711 720
712 721 follow : Dependency or collection of msg_ids
713 722 Only for load-balanced execution (targets=None)
714 723 Specify a list of msg_ids as a location-based dependency.
715 724 This job will only be run on an engine where this dependency
716 725 is met.
717 726
718 727 Returns
719 728 -------
720 729 if block is False:
721 730 if single target:
722 731 return msg_id
723 732 else:
724 733 return list of msg_ids
725 734 ? (should this be dict like block=True) ?
726 735 else:
727 736 if single target:
728 737 return result of f(*args, **kwargs)
729 738 else:
730 739 return dict of results, keyed by engine
731 740 """
732 741
733 742 # defaults:
734 743 block = block if block is not None else self.block
735 744 args = args if args is not None else []
736 745 kwargs = kwargs if kwargs is not None else {}
737 746
738 747 # enforce types of f,args,kwrags
739 748 if not callable(f):
740 749 raise TypeError("f must be callable, not %s"%type(f))
741 750 if not isinstance(args, (tuple, list)):
742 751 raise TypeError("args must be tuple or list, not %s"%type(args))
743 752 if not isinstance(kwargs, dict):
744 753 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
745 754
746 755 options = dict(bound=bound, block=block, after=after, follow=follow)
747 756
748 757 if targets is None:
749 758 return self._apply_balanced(f, args, kwargs, **options)
750 759 else:
751 760 return self._apply_direct(f, args, kwargs, targets=targets, **options)
752 761
753 762 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
754 763 after=None, follow=None):
755 764 """The underlying method for applying functions in a load balanced
756 765 manner, via the task queue."""
757 766 if isinstance(after, Dependency):
758 767 after = after.as_dict()
759 768 elif after is None:
760 769 after = []
761 770 if isinstance(follow, Dependency):
762 771 follow = follow.as_dict()
763 772 elif follow is None:
764 773 follow = []
765 774 subheader = dict(after=after, follow=follow)
766 775
767 776 bufs = ss.pack_apply_message(f,args,kwargs)
768 777 content = dict(bound=bound)
769 778 msg = self.session.send(self._task_socket, "apply_request",
770 779 content=content, buffers=bufs, subheader=subheader)
771 780 msg_id = msg['msg_id']
772 781 self.outstanding.add(msg_id)
773 782 self.history.append(msg_id)
774 783 if block:
775 784 self.barrier(msg_id)
776 785 return self._maybe_raise(self.results[msg_id])
777 786 else:
778 787 return msg_id
779 788
    def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
                        after=None, follow=None):
        """Then underlying method for applying functions to specific engines
        via the MUX queue."""

        queues,targets = self._build_targets(targets)
        bufs = ss.pack_apply_message(f,args,kwargs)
        # dependencies are normalized into plain dicts/lists for the subheader
        if isinstance(after, Dependency):
            after = after.as_dict()
        elif after is None:
            after = []
        if isinstance(follow, Dependency):
            follow = follow.as_dict()
        elif follow is None:
            follow = []
        subheader = dict(after=after, follow=follow)
        content = dict(bound=bound)
        # one request per engine queue, addressed by its zmq identity
        msg_ids = []
        for queue in queues:
            msg = self.session.send(self._mux_socket, "apply_request",
                    content=content, buffers=bufs,ident=queue, subheader=subheader)
            msg_id = msg['msg_id']
            self.outstanding.add(msg_id)
            self.history.append(msg_id)
            msg_ids.append(msg_id)
        if block:
            self.barrier(msg_ids)
        else:
            # nonblocking: return the id (or list of ids) immediately
            if len(msg_ids) == 1:
                return msg_ids[0]
            else:
                return msg_ids
        # blocking: unwrap a single result, or build a dict keyed by engine id
        if len(msg_ids) == 1:
            return self._maybe_raise(self.results[msg_ids[0]])
        else:
            result = {}
            for target,mid in zip(targets, msg_ids):
                result[target] = self.results[mid]
            return error.collect_exceptions(result, f.__name__)
819 828
820 829 #--------------------------------------------------------------------------
821 830 # Data movement
822 831 #--------------------------------------------------------------------------
823 832
824 833 @defaultblock
825 834 def push(self, ns, targets=None, block=None):
826 835 """Push the contents of `ns` into the namespace on `target`"""
827 836 if not isinstance(ns, dict):
828 837 raise TypeError("Must be a dict, not %s"%type(ns))
829 838 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
830 839 return result
831 840
832 841 @defaultblock
833 842 def pull(self, keys, targets=None, block=True):
834 843 """Pull objects from `target`'s namespace by `keys`"""
835 844 if isinstance(keys, str):
836 845 pass
837 846 elif isinstance(keys, (list,tuple,set)):
838 847 for key in keys:
839 848 if not isinstance(key, str):
840 849 raise TypeError
841 850 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
842 851 return result
843 852
844 853 #--------------------------------------------------------------------------
845 854 # Query methods
846 855 #--------------------------------------------------------------------------
847 856
    @spinfirst
    def get_results(self, msg_ids, status_only=False):
        """Returns the result of the execute or task request with `msg_ids`.

        Parameters
        ----------
        msg_ids : list of ints or msg_ids
            if int:
                Passed as index to self.history for convenience.
        status_only : bool (default: False)
            if False:
                return the actual results
        """
        if not isinstance(msg_ids, (list,tuple)):
            msg_ids = [msg_ids]
        theids = []
        for msg_id in msg_ids:
            # ints index into the local submission history
            if isinstance(msg_id, int):
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, str):
                raise TypeError("msg_ids must be str, not %r"%msg_id)
            theids.append(msg_id)

        # split ids into locally-cached results and ones we must fetch
        completed = []
        local_results = {}
        for msg_id in list(theids):
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids: # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            msg = self.session.send(self._query_socket, "result_request", content=content)
            # block until the reply is ready, then read it without blocking
            zmq.select([self._query_socket], [], [])
            idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise ss.unwrap_exception(content)
        else:
            content = dict(completed=[],pending=[])
        if not status_only:
            # load cached results into result:
            content['completed'].extend(completed)
            content.update(local_results)
            # update cache with results:
            # NOTE(review): this loop iterates the raw `msg_ids`, which may
            # still contain ints — the normalized str ids (theids/completed)
            # look like the intended keys; confirm
            for msg_id in msg_ids:
                if msg_id in content['completed']:
                    self.results[msg_id] = content[msg_id]
        return content
900 909
@spinfirst
def queue_status(self, targets=None, verbose=False):
    """Fetch the status of engine queues.

    Parameters
    ----------
    targets : int/str/list of ints/strs
        the engines on which to execute
        default : all
    verbose : bool
        Whether to return lengths only, or lists of ids for each element
    """
    engine_ids = self._build_targets(targets)[1]
    self.session.send(self._query_socket, "queue_request",
            content=dict(targets=engine_ids, verbose=verbose))
    idents, reply = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(reply)
    info = reply['content']
    # a non-ok status frame carries a wrapped remote exception
    if info.pop('status') != 'ok':
        raise ss.unwrap_exception(info)
    return info
924 933
@spinfirst
def purge_results(self, msg_ids=None, targets=None):
    """Tell the controller to forget results.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Parameters
    ----------
    msg_ids : str or list of strs, optional
        the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs, optional
        The targets, by uuid or int_id, whose entire history is to be purged.
        Use `targets='all'` to scrub everything from the controller's memory.

    Raises
    ------
    ValueError : if neither msg_ids nor targets is given
    """
    # fix: replaced mutable default arguments ([]) with None sentinels
    msg_ids = [] if msg_ids is None else msg_ids
    targets = [] if targets is None else targets
    if not targets and not msg_ids:
        # previously raised a bare ValueError with no message
        raise ValueError("Must specify at least one of msg_ids or targets")
    if targets:
        targets = self._build_targets(targets)[1]
    content = dict(targets=targets, msg_ids=msg_ids)
    self.session.send(self._query_socket, "purge_request", content=content)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise ss.unwrap_exception(content)
950 963
class AsynClient(Client):
    """An Asynchronous client, using the Tornado Event Loop.
    !!!unfinished!!!"""
    io_loop = None
    _queue_stream = None
    _notifier_stream = None
    _task_stream = None
    _control_stream = None

    def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
        Client.__init__(self, addr, context, username, debug)
        if io_loop is None:
            io_loop = ioloop.IOLoop.instance()
        self.io_loop = io_loop

        self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
        self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
        self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
        # fix: store on `_notifier_stream` (the attribute declared above);
        # previously this was assigned to `_notification_stream`, leaving the
        # declared attribute always None
        self._notifier_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)

    def spin(self):
        """Flush all streams so pending message callbacks fire."""
        # fix: reference the actual underscore-prefixed stream attributes;
        # previously used nonexistent self.queue_stream etc., which would
        # raise AttributeError on the first call
        for stream in (self._queue_stream, self._notifier_stream,
                       self._task_stream, self._control_stream):
            stream.flush()
975 988
@@ -1,530 +1,544 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 20 from IPython.zmq.parallel.error import RemoteError
21 21
22 22 try:
23 23 import cPickle
24 24 pickle = cPickle
25 25 except:
26 26 cPickle = None
27 27 import pickle
28 28
# packer priority: jsonlib[2], cPickle, simplejson/json, pickle
json_name = jsonapi.jsonmod.__name__ if jsonapi.jsonmod else ''
if json_name in ('jsonlib', 'jsonlib2'):
    # jsonlib is fast enough to be preferred outright
    use_json = True
elif json_name and cPickle is None:
    # some json module is available and there is no cPickle to prefer
    use_json = True
else:
    use_json = False
40 40
def squash_unicode(obj):
    """Recursively convert unicode strings in `obj` to utf8-encoded str.

    Dicts and lists are modified in place (keys included); the (possibly
    new) object is returned.
    """
    if isinstance(obj, dict):
        for key in obj.keys():
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode):
                # replace the unicode key with its encoded form
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for idx, item in enumerate(obj):
            obj[idx] = squash_unicode(item)
    elif isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return obj
53 53
# choose the default wire serializers based on the availability check above:
# a fast json implementation when present, otherwise highest-protocol pickle
if use_json:
    default_packer = jsonapi.dumps
    # json decodes str as unicode; squash back to utf8 str for zmq
    default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
else:
    default_packer = lambda o: pickle.dumps(o,-1)
    default_unpacker = pickle.loads


# frame delimiter separating zmq routing identities from message frames
DELIM="<IDS|MSG>"
# timestamp format used for the `date` field in message headers
ISO8601="%Y-%m-%dT%H:%M:%S.%f"
64 64
def wrap_exception(engine_info=None):
    """Package the current exception (from sys.exc_info) into a content dict.

    Must be called from inside an `except` block.

    Parameters
    ----------
    engine_info : dict, optional
        extra information identifying the engine where the error occurred

    Returns
    -------
    dict with keys status/traceback/ename/evalue/engine_info, suitable as
    the content of an error reply message.
    """
    # fix: engine_info={} was a shared mutable default argument
    if engine_info is None:
        engine_info = {}
    etype, evalue, tb = sys.exc_info()
    stb = traceback.format_exception(etype, evalue, tb)
    exc_content = {
        'status' : 'error',
        'traceback' : stb,
        'ename' : unicode(etype.__name__),
        'evalue' : unicode(evalue),
        'engine_info' : engine_info
    }
    return exc_content
76 76
def unwrap_exception(content):
    """Rebuild a RemoteError from a content dict produced by wrap_exception."""
    tb = ''.join(content['traceback'])
    err = RemoteError(content['ename'], content['evalue'], tb,
                content.get('engine_info', {}))
    return err
82 82
83 83
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in dict(msg_dict).iteritems():
            # nested dicts become nested Messages
            if isinstance(value, dict):
                value = Message(value)
            attrs[key] = value

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.iteritems())

    def __getitem__(self, key):
        return self.__dict__[key]

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)
113 113
def msg_header(msg_id, msg_type, username, session):
    """Build a header dict for a message, stamped with the current time."""
    return dict(msg_id=msg_id, msg_type=msg_type, username=username,
                session=session, date=datetime.now().strftime(ISO8601))
117 117
def extract_header(msg_or_header):
    """Given a message or header, return the header.

    A falsy argument yields an empty dict.  A KeyError propagates when the
    argument is neither a full message (no 'header' key) nor a header
    itself (no 'msg_id' key).
    """
    if not msg_or_header:
        return {}
    try:
        # a full message carries its header under 'header'
        header = msg_or_header['header']
    except KeyError:
        # otherwise it must already be a header: probe for 'msg_id',
        # letting the KeyError escape if it is absent
        msg_or_header['msg_id']
        header = msg_or_header
    if not isinstance(header, dict):
        header = dict(header)
    return header
136 136
def rekey(dikt):
    """Rekey a dict that has been forced to use str keys where there should be
    ints by json. This belongs in the jsonutil added by fperez.

    Mutates and returns `dikt`, converting numeric-looking str keys to int
    (preferred) or float.  Non-numeric keys are left alone.

    Raises
    ------
    KeyError : if a converted key collides with an existing key.
    """
    # fix: iterate a snapshot of the keys.  The original iterated
    # dikt.iterkeys() while popping/inserting entries, which raises
    # "dictionary changed size during iteration" on the first conversion.
    for k in list(dikt.keys()):
        if isinstance(k, str):
            ik = fk = None
            try:
                ik = int(k)
            except ValueError:
                try:
                    fk = float(k)
                except ValueError:
                    continue
            nk = ik if ik is not None else fk
            if nk in dikt:
                raise KeyError("already have key %r"%nk)
            dikt[nk] = dikt.pop(k)
    return dikt
158 158
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.
        NOTE(review): compared directly against getDataSize(); presumably
        both are in MB -- confirm against newserialized.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        # serialize each element; large/binary payloads are extracted into
        # databuffers and their slots nulled so they aren't double-pickled
        clist = canSequence(obj)
        slist = map(serialize, clist)
        for s in slist:
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sort keys so the buffer order is deterministic and matches the
        # sorted iteration performed on the receiving side
        for k in sorted(obj.iterkeys()):
            s = serialize(can(obj[k]))
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
            sobj[k] = s
        return pickle.dumps(sobj,-1),databuffers
    else:
        # scalar case: a single serialized object
        s = serialize(can(obj))
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return pickle.dumps(s,-1),databuffers
201 201
202 202
def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers.

    The first buffer is the pickled metadata wrapper (list, dict, or single
    serialized object); subsequent buffers refill any extracted `data`
    slots in the same order they were appended by serialize_object.
    """
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        # refill extracted payloads in order
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj))
    elif isinstance(sobj, dict):
        newobj = {}
        # sorted() mirrors the sorted key iteration used when packing,
        # so buffers are consumed in the matching order
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj
    else:
        # single-object case
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj))
224 224
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    serialized_args, arg_bufs = serialize_object(args, threshold)
    serialized_kwargs, kwarg_bufs = serialize_object(kwargs, threshold)
    # wire layout: [canned f, packed args, packed kwargs, data buffers...]
    msg = [pickle.dumps(can(f),-1), serialized_args, serialized_kwargs]
    msg.extend(arg_bufs)
    msg.extend(kwarg_bufs)
    return msg
239 239
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs

    Parameters
    ----------
    bufs : list
        buffers as produced by pack_apply_message: [f, args, kwargs, data...]
    g : dict, optional
        namespace used when uncanning f and arguments
    copy : bool
        when False, bufs are zmq Message frames and must be unwrapped
        via .bytes / .buffer rather than used directly
    """
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # the first three frames are always unpickled, so grab their bytes
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    # print sargs, skwargs
    f = uncan(cf, g)
    # refill positional-arg payloads in order from the remaining buffers
    for sa in sargs:
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                if copy:
                    sa.data = buffer(m)
                else:
                    # zero-copy: keep a view on the zmq frame
                    sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted() matches the buffer order produced by serialize_object
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            sa.data = bufs.pop(0)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
276 276
class StreamSession(object):
    """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
    debug=False     # when True, pprint every message sent
    key=None        # shared-secret frame included in every message when set

    def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
        """Set up the session: identity, serializer pair, and optional auth key.

        Parameters
        ----------
        username : str, optional
            defaults to $USER or 'username'
        session : str, optional
            session uuid; a fresh uuid4 is generated when omitted
        packer, unpacker : callables, optional
            serializer pair; defaults chosen at module import (json or pickle)
        key : str, optional
            shared authentication key; mutually exclusive with keyfile
        keyfile : str, optional
            path to a file whose stripped contents are used as the key
        """
        if username is None:
            username = os.environ.get('USER','username')
        self.username = username
        if session is None:
            self.session = str(uuid.uuid4())
        else:
            self.session = session
        self.msg_id = str(uuid.uuid4())
        if packer is None:
            self.pack = default_packer
        else:
            if not callable(packer):
                raise TypeError("packer must be callable, not %s"%type(packer))
            self.pack = packer

        if unpacker is None:
            self.unpack = default_unpacker
        else:
            if not callable(unpacker):
                raise TypeError("unpacker must be callable, not %s"%type(unpacker))
            self.unpack = unpacker

        if key is not None and keyfile is not None:
            raise TypeError("Must specify key OR keyfile, not both")
        if keyfile is not None:
            with open(keyfile) as f:
                self.key = f.read().strip()
        else:
            self.key = key
        # pre-packed empty content dict, reused for content-less messages
        self.none = self.pack({})

    def msg_header(self, msg_type):
        """Return a new header, consuming and refreshing self.msg_id."""
        h = msg_header(self.msg_id, msg_type, self.username, self.session)
        self.msg_id = str(uuid.uuid4())
        return h

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Build a complete message dict (header/parent_header/content)."""
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def check_key(self, msg_or_header):
        """Check that a message's header has the right key"""
        if self.key is None:
            return True
        header = extract_header(msg_or_header)
        return header.get('key', None) == self.key


    def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
        """Build and send a message via stream or socket.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_type : str or Message/dict
            if str: the msg_type of a new message built via self.msg();
            if Message/dict: an already-built message, sent as-is
        content : dict, str, or None
            message content; a str is assumed to be already packed
        buffers : list, optional
            raw data buffers appended after the packed frames
        parent : Message/dict, optional
            the message being replied to (becomes parent_header)
        subheader : dict, optional
            extra keys merged into the header
        ident : str or list of str, optional
            zmq routing identit(y|ies) prepended to the wire message

        Returns
        -------
        omsg : Message
            the nice wrapped dict-like object containing the headers
        """
        if isinstance(msg_type, (Message, dict)):
            # we got a Message, not a msg_type
            # don't build a new Message
            msg = msg_type
            content = msg['content']
        else:
            msg = self.msg(msg_type, content, parent, subheader)
        buffers = [] if buffers is None else buffers
        to_send = []
        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)
        if self.key is not None:
            to_send.append(self.key)
        to_send.append(self.pack(msg['header']))
        to_send.append(self.pack(msg['parent_header']))

        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, str):
            # content is already packed, as in a relayed message
            pass
        else:
            raise TypeError("Content incorrect type: %s"%type(content))
        to_send.append(content)
        flag = 0
        if buffers:
            flag = zmq.SNDMORE
        stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            # last buffer without SNDMORE terminates the multipart message
            stream.send(buffers[-1], copy=False)
        omsg = Message(msg)
        if self.debug:
            pprint.pprint(omsg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)
        return omsg

    def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        Parameters
        ----------
        msg : list of sendable buffers"""
        to_send = []
        if isinstance(ident, str):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)
        to_send.append(DELIM)
        if self.key is not None:
            to_send.append(self.key)
        to_send.extend(msg)
        # fix: send the assembled to_send list; previously this sent the
        # bare `msg`, silently dropping the ident prefix, DELIM, and key
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """receives and unpacks a message
        returns [idents], msg"""
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg = socket.recv_multipart(mode)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_json won't return None.
                return None
            else:
                raise
        # return an actual Message object
        # determine the number of idents by trying to unpack them.
        # this is terrible:
        idents, msg = self.feed_identities(msg, copy)
        try:
            return idents, self.unpack_message(msg, content=content, copy=copy)
        except Exception as e:
            print (idents, msg)
            # TODO: handle it
            raise e

    def feed_identities(self, msg, copy=True):
        """feed until DELIM is reached, then return the prefix as idents and remainder as
        msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.

        Parameters
        ----------
        msg : a list of Message or bytes objects
            the message to be split
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents,msg) : two lists
            idents will always be a list of bytes - the identity prefix
            msg will be a list of bytes or Messages, unchanged from input
            msg should be unpackable via self.unpack_message at this point.
        """
        msg = list(msg)
        idents = []
        # anything before DELIM is routing identity; the final message is
        # always at least 3 frames (header, parent_header, content)
        while len(msg) > 3:
            if copy:
                s = msg[0]
            else:
                s = msg[0].bytes
            if s == DELIM:
                msg.pop(0)
                break
            else:
                idents.append(s)
                msg.pop(0)

        return idents, msg

    def unpack_message(self, msg, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.

        Parameters:
        -----------

        content : bool (True)
            whether to unpack the content dict (True),
            or leave it serialized (False)

        copy : bool (True)
            whether to return the bytes (True),
            or the non-copying Message object in each place (False)

        """
        # when a key is in use, it occupies the first frame
        ikey = int(self.key is not None)
        minlen = 3 + ikey
        if not len(msg) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg[i] = msg[i].bytes
        if ikey:
            if not self.key == msg[0]:
                raise KeyError("Invalid Session Key: %s"%msg[0])
        message['header'] = self.unpack(msg[ikey+0])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg[ikey+1])
        if content:
            message['content'] = self.unpack(msg[ikey+2])
        else:
            message['content'] = msg[ikey+2]

        message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
        return message
513 527
514 528
515 529
def test_msg2obj():
    """round-trip checks between dicts and Message objects"""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    k1, k2 = 'y', 'z'
    assert msg[k1][k2] == source[k1][k2]

    # dict(Message) must recover the original values
    roundtripped = dict(msg)
    assert source['x'] == roundtripped['x']
    assert source['y']['z'] == roundtripped['y']['z']
@@ -1,220 +1,264 b''
1 #!/usr/bin/env python
2 """Views"""
1 """Views of remote engines"""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
8
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
3 12
4 13 from IPython.external.decorator import decorator
5 14
15 #-----------------------------------------------------------------------------
16 # Decorators
17 #-----------------------------------------------------------------------------
6 18
@decorator
def myblock(f, self, *args, **kwargs):
    """Temporarily override client.block with self.block during a call."""
    previous = self.client.block
    self.client.block = self.block
    result = f(self, *args, **kwargs)
    # restore the client's original blocking behavior
    self.client.block = previous
    return result
14 27
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call."""
    ret = f(self, *args, **kwargs)
    # the client appends one msg_id per target; grab the most recent batch
    recent = self.client.history[-self._ntargets:]
    self.history.extend(recent)
    for msg_id in recent:
        self.outstanding.add(msg_id)
    return ret
22 36
@decorator
def sync_results(f, self, *args, **kwargs):
    """sync relevant results from self.client to our results attribute."""
    ret = f(self, *args, **kwargs)
    # anything we still consider outstanding that the client no longer
    # does has completed since our last sync
    delta = self.outstanding.difference(self.client.outstanding)
    completed = self.outstanding.intersection(delta)
    self.outstanding = self.outstanding.difference(completed)
    for finished_id in completed:
        self.results[finished_id] = self.client.results[finished_id]
    return ret
32 47
@decorator
def spin_after(f, self, *args, **kwargs):
    """call spin after the method."""
    result = f(self, *args, **kwargs)
    self.spin()
    return result
38 54
55 #-----------------------------------------------------------------------------
56 # Classes
57 #-----------------------------------------------------------------------------
39 58
class View(object):
    """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.
    """
    # engine targets; fixed at construction (see read-only property below)
    _targets = None
    # number of engines addressed
    _ntargets = None
    # default blocking behavior, copied from the client at construction
    block=None
    # whether apply() runs bound to the engine namespace
    bound=None
    # msg_ids issued through this view
    history=None

    def __init__(self, client, targets=None):
        self.client = client
        self._targets = targets
        # a single int or None addresses one engine; otherwise count them
        self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
        self.block = client.block
        self.bound=True
        self.history = []
        # msg_ids sent but not yet completed; synced by @sync_results
        self.outstanding = set()
        # msg_id -> result cache, filled from the client by @sync_results
        self.results = {}

    def __repr__(self):
        strtargets = str(self._targets)
        if len(strtargets) > 16:
            # truncate long target lists for a compact repr
            strtargets = strtargets[:12]+'...]'
        return "<%s %s>"%(self.__class__.__name__, strtargets)

    @property
    def targets(self):
        """The engine targets of this view (read-only)."""
        return self._targets

    @targets.setter
    def targets(self, value):
        raise AttributeError("Cannot set my targets argument after construction!")

    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    @save_ids
    def apply(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method does not involve the engine's namespace.

        if self.block is False:
            returns msg_id
        else:
            returns actual result of f(*args, **kwargs)
        """
        return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound)

    @save_ids
    def apply_async(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

        This method does not involve the engine's namespace.

        returns msg_id
        """
        return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False)

    @spin_after
    @save_ids
    def apply_sync(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a blocking manner,
        returning the result.

        This method does not involve the engine's namespace.

        returns: actual result of f(*args, **kwargs)
        """
        return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False)

    @sync_results
    @save_ids
    def apply_bound(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) bound to engine namespace(s).

        if self.block is False:
            returns msg_id
        else:
            returns actual result of f(*args, **kwargs)

        This method has access to the targets' globals

        """
        return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True)

    @sync_results
    @save_ids
    def apply_async_bound(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) bound to engine namespace(s)
        in a nonblocking manner.

        returns: msg_id

        This method has access to the targets' globals

        """
        return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True)

    @spin_after
    @save_ids
    def apply_sync_bound(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result.

        returns: actual result of f(*args, **kwargs)

        This method has access to the targets' globals

        """
        return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True)

    def abort(self, msg_ids=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        msg_ids : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        block : bool, optional
            whether to wait; defaults to self.block.
        """
        block = block if block is not None else self.block
        return self.client.abort(msg_ids=msg_ids, targets=self.targets, block=block)

    def queue_status(self, verbose=False):
        """Fetch the Queue status of my engines"""
        return self.client.queue_status(targets=self.targets, verbose=verbose)

    def purge_results(self, msg_ids=[],targets=[]):
        """Instruct the controller to forget specific results."""
        # NOTE(review): mutable default arguments; harmless here since
        # neither list is mutated, but unidiomatic -- consider None sentinels
        if targets is None or targets == 'all':
            # an explicit None (or 'all') means this view's own targets
            targets = self.targets
        return self.client.purge_results(msg_ids=msg_ids, targets=targets)
175 197
176 198
class DirectView(View):
    """Direct Multiplexer View of one or more engines.

    These are created via indexed access to a client:

    >>> dv_1 = client[1]
    >>> dv_all = client[:]
    >>> dv_even = client[::2]
    >>> dv_some = client[1:3]

    """

    def update(self, ns):
        """update remote namespace with dict `ns`"""
        return self.client.push(ns, targets=self.targets, block=self.block)

    push = update

    def get(self, key_s):
        """get object(s) by `key_s` from remote namespace
        will return one object if it is a key.
        It also takes a list of keys, and will return a list of objects.

        This method always blocks."""
        return self.client.pull(key_s, block=True, targets=self.targets)

    def pull(self, key_s, block=None):
        """get object(s) by `key_s` from remote namespace
        will return one object if it is a key.
        It also takes a list of keys, and will return a list of objects.

        block defaults to self.block."""
        # fix: default was block=True, which made the self.block fallback
        # below unreachable dead code
        block = block if block is not None else self.block
        return self.client.pull(key_s, block=block, targets=self.targets)

    def __getitem__(self, key):
        # dict-style read access: dv['x'] -> dv.get('x')
        return self.get(key)

    def __setitem__(self,key, value):
        # dict-style write access: dv['x'] = 5 -> dv.update({'x': 5})
        self.update({key:value})

    def clear(self, block=None):
        """Clear the remote namespaces on my engines.

        block defaults to self.block."""
        # fix: default was block=False, making the fallback dead code
        block = block if block is not None else self.block
        return self.client.clear(targets=self.targets, block=block)

    def kill(self, block=None):
        """Kill my engines.

        block defaults to self.block."""
        # fix: default was block=True, making the fallback dead code
        block = block if block is not None else self.block
        return self.client.kill(targets=self.targets, block=block)
215 246
class LoadBalancedView(View):
    """An engine-agnostic View that only executes via the Task queue.

    Typically created via:

    >>> lbv = client[None]
    <LoadBalancedView tcp://127.0.0.1:12345>

    but can also be created with:

    >>> lbc = LoadBalancedView(client)

    TODO: allow subset of engines across which to balance.
    """
    def __repr__(self):
        clsname = self.__class__.__name__
        return "<%s %s>"%(clsname, self.client._addr)
219 263
220 264 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now