##// END OF EJS Templates
PendingResult->AsyncResult; match multiprocessing.AsyncResult api
MinRK -
Show More
@@ -0,0 +1,112
1 """AsyncResult objects for the client"""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
8
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
12
13 import error
14
15 #-----------------------------------------------------------------------------
16 # Classes
17 #-----------------------------------------------------------------------------
18
19 class AsyncResult(object):
20 """Class for representing results of non-blocking calls.
21
22 Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
23 """
24 def __init__(self, client, msg_ids):
25 self._client = client
26 self._msg_ids = msg_ids
27 self._ready = False
28 self._success = None
29
30 def __repr__(self):
31 if self._ready:
32 return "<%s: finished>"%(self.__class__.__name__)
33 else:
34 return "<%s: %r>"%(self.__class__.__name__,self._msg_ids)
35
36
37 def _reconstruct_result(self, res):
38 """
39 Override me in subclasses for turning a list of results
40 into the expected form.
41 """
42 if len(res) == 1:
43 return res[0]
44 else:
45 return res
46
47 def get(self, timeout=-1):
48 """Return the result when it arrives.
49
50 If `timeout` is not ``None`` and the result does not arrive within
51 `timeout` seconds then ``TimeoutError`` is raised. If the
52 remote call raised an exception then that exception will be reraised
53 by get().
54 """
55 if not self.ready():
56 self.wait(timeout)
57
58 if self._ready:
59 if self._success:
60 return self._result
61 else:
62 raise self._exception
63 else:
64 raise error.TimeoutError("Result not ready.")
65
66 def ready(self):
67 """Return whether the call has completed."""
68 if not self._ready:
69 self.wait(0)
70 return self._ready
71
72 def wait(self, timeout=-1):
73 """Wait until the result is available or until `timeout` seconds pass.
74 """
75 if self._ready:
76 return
77 self._ready = self._client.barrier(self._msg_ids, timeout)
78 if self._ready:
79 try:
80 results = map(self._client.results.get, self._msg_ids)
81 results = error.collect_exceptions(results, 'get')
82 self._result = self._reconstruct_result(results)
83 except Exception, e:
84 self._exception = e
85 self._success = False
86 else:
87 self._success = True
88
89
90 def successful(self):
91 """Return whether the call completed without raising an exception.
92
93 Will raise ``AssertionError`` if the result is not ready.
94 """
95 assert self._ready
96 return self._success
97
98 class AsyncMapResult(AsyncResult):
99 """Class for representing results of non-blocking gathers.
100
101 This will properly reconstruct the gather.
102 """
103
104 def __init__(self, client, msg_ids, mapObject):
105 self._mapObject = mapObject
106 AsyncResult.__init__(self, client, msg_ids)
107
108 def _reconstruct_result(self, res):
109 """Perform the gather on the actual results."""
110 return self._mapObject.joinPartitions(res)
111
112
@@ -1,1007 +1,1019
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import os
16 16 import time
17 17 from getpass import getpass
18 18 from pprint import pprint
19 19
20 20 import zmq
21 21 from zmq.eventloop import ioloop, zmqstream
22 22
23 23 from IPython.external.decorator import decorator
24 24 from IPython.zmq import tunnel
25 25
26 26 import streamsession as ss
27 27 # from remotenamespace import RemoteNamespace
28 28 from view import DirectView, LoadBalancedView
29 29 from dependency import Dependency, depend, require
30 30 import error
31 31 import map as Map
32 from pendingresult import PendingResult,PendingMapResult
32 from asyncresult import AsyncResult, AsyncMapResult
33 33 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
34 34
35 35 #--------------------------------------------------------------------------
36 36 # helpers for implementing old MEC API via client.apply
37 37 #--------------------------------------------------------------------------
38 38
39 39 def _push(ns):
40 40 """helper method for implementing `client.push` via `client.apply`"""
41 41 globals().update(ns)
42 42
43 43 def _pull(keys):
44 44 """helper method for implementing `client.pull` via `client.apply`"""
45 45 g = globals()
46 46 if isinstance(keys, (list,tuple, set)):
47 47 for key in keys:
48 48 if not g.has_key(key):
49 49 raise NameError("name '%s' is not defined"%key)
50 50 return map(g.get, keys)
51 51 else:
52 52 if not g.has_key(keys):
53 53 raise NameError("name '%s' is not defined"%keys)
54 54 return g.get(keys)
55 55
56 56 def _clear():
57 57 """helper method for implementing `client.clear` via `client.apply`"""
58 58 globals().clear()
59 59
60 60 def execute(code):
61 61 """helper method for implementing `client.execute` via `client.apply`"""
62 62 exec code in globals()
63 63
64 64
65 65 #--------------------------------------------------------------------------
66 66 # Decorators for Client methods
67 67 #--------------------------------------------------------------------------
68 68
69 69 @decorator
70 70 def spinfirst(f, self, *args, **kwargs):
71 71 """Call spin() to sync state prior to calling the method."""
72 72 self.spin()
73 73 return f(self, *args, **kwargs)
74 74
75 75 @decorator
76 76 def defaultblock(f, self, *args, **kwargs):
77 77 """Default to self.block; preserve self.block."""
78 78 block = kwargs.get('block',None)
79 79 block = self.block if block is None else block
80 80 saveblock = self.block
81 81 self.block = block
82 82 ret = f(self, *args, **kwargs)
83 83 self.block = saveblock
84 84 return ret
85 85
86 86
87 87 class AbortedTask(object):
88 88 """A basic wrapper object describing an aborted task."""
89 89 def __init__(self, msg_id):
90 90 self.msg_id = msg_id
91 91
92 92 class ResultDict(dict):
93 93 """A subclass of dict that raises errors if it has them."""
94 94 def __getitem__(self, key):
95 95 res = dict.__getitem__(self, key)
96 96 if isinstance(res, error.KernelError):
97 97 raise res
98 98 return res
99 99
100 100 class Client(object):
101 101 """A semi-synchronous client to the IPython ZMQ controller
102 102
103 103 Parameters
104 104 ----------
105 105
106 106 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
107 107 The address of the controller's registration socket.
108 108 [Default: 'tcp://127.0.0.1:10101']
109 109 context : zmq.Context
110 110 Pass an existing zmq.Context instance, otherwise the client will create its own
111 111 username : bytes
112 112 set username to be passed to the Session object
113 113 debug : bool
114 114 flag for lots of message printing for debug purposes
115 115
116 116 #-------------- ssh related args ----------------
117 117 # These are args for configuring the ssh tunnel to be used
118 118 # credentials are used to forward connections over ssh to the Controller
119 119 # Note that the ip given in `addr` needs to be relative to sshserver
120 120 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
121 121 # and set sshserver as the same machine the Controller is on. However,
122 122 # the only requirement is that sshserver is able to see the Controller
123 123 # (i.e. is within the same trusted network).
124 124
125 125 sshserver : str
126 126 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
127 127 If keyfile or password is specified, and this is not, it will default to
128 128 the ip given in addr.
129 129 sshkey : str; path to public ssh key file
130 130 This specifies a key to be used in ssh login, default None.
131 131 Regular default ssh keys will be used without specifying this argument.
132 132 password : str;
133 133 Your ssh password to sshserver. Note that if this is left None,
134 134 you will be prompted for it if passwordless key based login is unavailable.
135 135
136 136 #------- exec authentication args -------
137 137 # If even localhost is untrusted, you can have some protection against
138 138 # unauthorized execution by using a key. Messages are still sent
139 139 # as cleartext, so if someone can snoop your loopback traffic this will
140 140 # not help anything.
141 141
142 142 exec_key : str
143 143 an authentication key or file containing a key
144 144 default: None
145 145
146 146
147 147 Attributes
148 148 ----------
149 149 ids : set of int engine IDs
150 150 requesting the ids attribute always synchronizes
151 151 the registration state. To request ids without synchronization,
152 152 use semi-private _ids attributes.
153 153
154 154 history : list of msg_ids
155 155 a list of msg_ids, keeping track of all the execution
156 156 messages you have submitted in order.
157 157
158 158 outstanding : set of msg_ids
159 159 a set of msg_ids that have been submitted, but whose
160 160 results have not yet been received.
161 161
162 162 results : dict
163 163 a dict of all our results, keyed by msg_id
164 164
165 165 block : bool
166 166 determines default behavior when block not specified
167 167 in execution methods
168 168
169 169 Methods
170 170 -------
171 171 spin : flushes incoming results and registration state changes
172 172 control methods spin, and requesting `ids` also ensures up to date
173 173
174 174 barrier : wait on one or more msg_ids
175 175
176 176 execution methods: apply/apply_bound/apply_to/apply_bound
177 177 legacy: execute, run
178 178
179 179 query methods: queue_status, get_result, purge
180 180
181 181 control methods: abort, kill
182 182
183 183 """
184 184
185 185
186 186 _connected=False
187 187 _ssh=False
188 188 _engines=None
189 189 _addr='tcp://127.0.0.1:10101'
190 190 _registration_socket=None
191 191 _query_socket=None
192 192 _control_socket=None
193 193 _notification_socket=None
194 194 _mux_socket=None
195 195 _task_socket=None
196 196 block = False
197 197 outstanding=None
198 198 results = None
199 199 history = None
200 200 debug = False
201 201
202 202 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
203 203 sshserver=None, sshkey=None, password=None, paramiko=None,
204 204 exec_key=None,):
205 205 if context is None:
206 206 context = zmq.Context()
207 207 self.context = context
208 208 self._addr = addr
209 209 self._ssh = bool(sshserver or sshkey or password)
210 210 if self._ssh and sshserver is None:
211 211 # default to the same
212 212 sshserver = addr.split('://')[1].split(':')[0]
213 213 if self._ssh and password is None:
214 214 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
215 215 password=False
216 216 else:
217 217 password = getpass("SSH Password for %s: "%sshserver)
218 218 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
219 219
220 220 if exec_key is not None and os.path.isfile(exec_key):
221 221 arg = 'keyfile'
222 222 else:
223 223 arg = 'key'
224 224 key_arg = {arg:exec_key}
225 225 if username is None:
226 226 self.session = ss.StreamSession(**key_arg)
227 227 else:
228 228 self.session = ss.StreamSession(username, **key_arg)
229 229 self._registration_socket = self.context.socket(zmq.XREQ)
230 230 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
231 231 if self._ssh:
232 232 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
233 233 else:
234 234 self._registration_socket.connect(addr)
235 235 self._engines = {}
236 236 self._ids = set()
237 237 self.outstanding=set()
238 238 self.results = {}
239 239 self.history = []
240 240 self.debug = debug
241 241 self.session.debug = debug
242 242
243 243 self._notification_handlers = {'registration_notification' : self._register_engine,
244 244 'unregistration_notification' : self._unregister_engine,
245 245 }
246 246 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
247 247 'apply_reply' : self._handle_apply_reply}
248 248 self._connect(sshserver, ssh_kwargs)
249 249
250 250
251 251 @property
252 252 def ids(self):
253 253 """Always up to date ids property."""
254 254 self._flush_notifications()
255 255 return self._ids
256 256
257 257 def _update_engines(self, engines):
258 258 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
259 259 for k,v in engines.iteritems():
260 260 eid = int(k)
261 261 self._engines[eid] = bytes(v) # force not unicode
262 262 self._ids.add(eid)
263 263
264 264 def _build_targets(self, targets):
265 265 """Turn valid target IDs or 'all' into two lists:
266 266 (int_ids, uuids).
267 267 """
268 268 if targets is None:
269 269 targets = self._ids
270 270 elif isinstance(targets, str):
271 271 if targets.lower() == 'all':
272 272 targets = self._ids
273 273 else:
274 274 raise TypeError("%r not valid str target, must be 'all'"%(targets))
275 275 elif isinstance(targets, int):
276 276 targets = [targets]
277 277 return [self._engines[t] for t in targets], list(targets)
278 278
279 279 def _connect(self, sshserver, ssh_kwargs):
280 280 """setup all our socket connections to the controller. This is called from
281 281 __init__."""
282 282 if self._connected:
283 283 return
284 284 self._connected=True
285 285
286 286 def connect_socket(s, addr):
287 287 if self._ssh:
288 288 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
289 289 else:
290 290 return s.connect(addr)
291 291
292 292 self.session.send(self._registration_socket, 'connection_request')
293 293 idents,msg = self.session.recv(self._registration_socket,mode=0)
294 294 if self.debug:
295 295 pprint(msg)
296 296 msg = ss.Message(msg)
297 297 content = msg.content
298 298 if content.status == 'ok':
299 299 if content.queue:
300 300 self._mux_socket = self.context.socket(zmq.PAIR)
301 301 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 302 connect_socket(self._mux_socket, content.queue)
303 303 if content.task:
304 304 self._task_socket = self.context.socket(zmq.PAIR)
305 305 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
306 306 connect_socket(self._task_socket, content.task)
307 307 if content.notification:
308 308 self._notification_socket = self.context.socket(zmq.SUB)
309 309 connect_socket(self._notification_socket, content.notification)
310 310 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
311 311 if content.query:
312 312 self._query_socket = self.context.socket(zmq.PAIR)
313 313 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
314 314 connect_socket(self._query_socket, content.query)
315 315 if content.control:
316 316 self._control_socket = self.context.socket(zmq.PAIR)
317 317 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
318 318 connect_socket(self._control_socket, content.control)
319 319 self._update_engines(dict(content.engines))
320 320
321 321 else:
322 322 self._connected = False
323 323 raise Exception("Failed to connect!")
324 324
325 325 #--------------------------------------------------------------------------
326 326 # handlers and callbacks for incoming messages
327 327 #--------------------------------------------------------------------------
328 328
329 329 def _register_engine(self, msg):
330 330 """Register a new engine, and update our connection info."""
331 331 content = msg['content']
332 332 eid = content['id']
333 333 d = {eid : content['queue']}
334 334 self._update_engines(d)
335 335 self._ids.add(int(eid))
336 336
337 337 def _unregister_engine(self, msg):
338 338 """Unregister an engine that has died."""
339 339 content = msg['content']
340 340 eid = int(content['id'])
341 341 if eid in self._ids:
342 342 self._ids.remove(eid)
343 343 self._engines.pop(eid)
344 344
345 345 def _handle_execute_reply(self, msg):
346 346 """Save the reply to an execute_request into our results."""
347 347 parent = msg['parent_header']
348 348 msg_id = parent['msg_id']
349 349 if msg_id not in self.outstanding:
350 350 print("got unknown result: %s"%msg_id)
351 351 else:
352 352 self.outstanding.remove(msg_id)
353 353 self.results[msg_id] = ss.unwrap_exception(msg['content'])
354 354
355 355 def _handle_apply_reply(self, msg):
356 356 """Save the reply to an apply_request into our results."""
357 357 parent = msg['parent_header']
358 358 msg_id = parent['msg_id']
359 359 if msg_id not in self.outstanding:
360 360 print ("got unknown result: %s"%msg_id)
361 361 else:
362 362 self.outstanding.remove(msg_id)
363 363 content = msg['content']
364 364 if content['status'] == 'ok':
365 365 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
366 366 elif content['status'] == 'aborted':
367 367 self.results[msg_id] = error.AbortedTask(msg_id)
368 368 elif content['status'] == 'resubmitted':
369 369 # TODO: handle resubmission
370 370 pass
371 371 else:
372 372 e = ss.unwrap_exception(content)
373 373 e_uuid = e.engine_info['engineid']
374 374 for k,v in self._engines.iteritems():
375 375 if v == e_uuid:
376 376 e.engine_info['engineid'] = k
377 377 break
378 378 self.results[msg_id] = e
379 379
380 380 def _flush_notifications(self):
381 381 """Flush notifications of engine registrations waiting
382 382 in ZMQ queue."""
383 383 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
384 384 while msg is not None:
385 385 if self.debug:
386 386 pprint(msg)
387 387 msg = msg[-1]
388 388 msg_type = msg['msg_type']
389 389 handler = self._notification_handlers.get(msg_type, None)
390 390 if handler is None:
391 391 raise Exception("Unhandled message type: %s"%msg.msg_type)
392 392 else:
393 393 handler(msg)
394 394 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
395 395
396 396 def _flush_results(self, sock):
397 397 """Flush task or queue results waiting in ZMQ queue."""
398 398 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
399 399 while msg is not None:
400 400 if self.debug:
401 401 pprint(msg)
402 402 msg = msg[-1]
403 403 msg_type = msg['msg_type']
404 404 handler = self._queue_handlers.get(msg_type, None)
405 405 if handler is None:
406 406 raise Exception("Unhandled message type: %s"%msg.msg_type)
407 407 else:
408 408 handler(msg)
409 409 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
410 410
411 411 def _flush_control(self, sock):
412 412 """Flush replies from the control channel waiting
413 413 in the ZMQ queue.
414 414
415 415 Currently: ignore them."""
416 416 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
417 417 while msg is not None:
418 418 if self.debug:
419 419 pprint(msg)
420 420 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
421 421
422 422 #--------------------------------------------------------------------------
423 423 # getitem
424 424 #--------------------------------------------------------------------------
425 425
426 426 def __getitem__(self, key):
427 427 """Dict access returns DirectView multiplexer objects or,
428 428 if key is None, a LoadBalancedView."""
429 429 if key is None:
430 430 return LoadBalancedView(self)
431 431 if isinstance(key, int):
432 432 if key not in self.ids:
433 433 raise IndexError("No such engine: %i"%key)
434 434 return DirectView(self, key)
435 435
436 436 if isinstance(key, slice):
437 437 indices = range(len(self.ids))[key]
438 438 ids = sorted(self._ids)
439 439 key = [ ids[i] for i in indices ]
440 440 # newkeys = sorted(self._ids)[thekeys[k]]
441 441
442 442 if isinstance(key, (tuple, list, xrange)):
443 443 _,targets = self._build_targets(list(key))
444 444 return DirectView(self, targets)
445 445 else:
446 446 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
447 447
448 448 #--------------------------------------------------------------------------
449 449 # Begin public methods
450 450 #--------------------------------------------------------------------------
451 451
452 452 @property
453 453 def remote(self):
454 454 """property for convenient RemoteFunction generation.
455 455
456 456 >>> @client.remote
457 457 ... def f():
458 458 import os
459 459 print (os.getpid())
460 460 """
461 461 return remote(self, block=self.block)
462 462
463 463 def spin(self):
464 464 """Flush any registration notifications and execution results
465 465 waiting in the ZMQ queue.
466 466 """
467 467 if self._notification_socket:
468 468 self._flush_notifications()
469 469 if self._mux_socket:
470 470 self._flush_results(self._mux_socket)
471 471 if self._task_socket:
472 472 self._flush_results(self._task_socket)
473 473 if self._control_socket:
474 474 self._flush_control(self._control_socket)
475 475
476 476 def barrier(self, msg_ids=None, timeout=-1):
477 477 """waits on one or more `msg_ids`, for up to `timeout` seconds.
478 478
479 479 Parameters
480 480 ----------
481 481 msg_ids : int, str, or list of ints and/or strs
482 482 ints are indices to self.history
483 483 strs are msg_ids
484 484 default: wait on all outstanding messages
485 485 timeout : float
486 486 a time in seconds, after which to give up.
487 487 default is -1, which means no timeout
488 488
489 489 Returns
490 490 -------
491 491 True : when all msg_ids are done
492 492 False : timeout reached, some msg_ids still outstanding
493 493 """
494 494 tic = time.time()
495 495 if msg_ids is None:
496 496 theids = self.outstanding
497 497 else:
498 498 if isinstance(msg_ids, (int, str)):
499 499 msg_ids = [msg_ids]
500 500 theids = set()
501 501 for msg_id in msg_ids:
502 502 if isinstance(msg_id, int):
503 503 msg_id = self.history[msg_id]
504 504 theids.add(msg_id)
505 505 self.spin()
506 506 while theids.intersection(self.outstanding):
507 507 if timeout >= 0 and ( time.time()-tic ) > timeout:
508 508 break
509 509 time.sleep(1e-3)
510 510 self.spin()
511 511 return len(theids.intersection(self.outstanding)) == 0
512 512
513 513 #--------------------------------------------------------------------------
514 514 # Control methods
515 515 #--------------------------------------------------------------------------
516 516
517 517 @spinfirst
518 518 @defaultblock
519 519 def clear(self, targets=None, block=None):
520 520 """Clear the namespace in target(s)."""
521 521 targets = self._build_targets(targets)[0]
522 522 for t in targets:
523 523 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
524 524 error = False
525 525 if self.block:
526 526 for i in range(len(targets)):
527 527 idents,msg = self.session.recv(self._control_socket,0)
528 528 if self.debug:
529 529 pprint(msg)
530 530 if msg['content']['status'] != 'ok':
531 531 error = ss.unwrap_exception(msg['content'])
532 532 if error:
533 533 return error
534 534
535 535
536 536 @spinfirst
537 537 @defaultblock
538 538 def abort(self, msg_ids = None, targets=None, block=None):
539 539 """Abort the execution queues of target(s)."""
540 540 targets = self._build_targets(targets)[0]
541 541 if isinstance(msg_ids, basestring):
542 542 msg_ids = [msg_ids]
543 543 content = dict(msg_ids=msg_ids)
544 544 for t in targets:
545 545 self.session.send(self._control_socket, 'abort_request',
546 546 content=content, ident=t)
547 547 error = False
548 548 if self.block:
549 549 for i in range(len(targets)):
550 550 idents,msg = self.session.recv(self._control_socket,0)
551 551 if self.debug:
552 552 pprint(msg)
553 553 if msg['content']['status'] != 'ok':
554 554 error = ss.unwrap_exception(msg['content'])
555 555 if error:
556 556 return error
557 557
558 558 @spinfirst
559 559 @defaultblock
560 560 def shutdown(self, targets=None, restart=False, controller=False, block=None):
561 561 """Terminates one or more engine processes, optionally including the controller."""
562 562 if controller:
563 563 targets = 'all'
564 564 targets = self._build_targets(targets)[0]
565 565 for t in targets:
566 566 self.session.send(self._control_socket, 'shutdown_request',
567 567 content={'restart':restart},ident=t)
568 568 error = False
569 569 if block or controller:
570 570 for i in range(len(targets)):
571 571 idents,msg = self.session.recv(self._control_socket,0)
572 572 if self.debug:
573 573 pprint(msg)
574 574 if msg['content']['status'] != 'ok':
575 575 error = ss.unwrap_exception(msg['content'])
576 576
577 577 if controller:
578 578 time.sleep(0.25)
579 579 self.session.send(self._query_socket, 'shutdown_request')
580 580 idents,msg = self.session.recv(self._query_socket, 0)
581 581 if self.debug:
582 582 pprint(msg)
583 583 if msg['content']['status'] != 'ok':
584 584 error = ss.unwrap_exception(msg['content'])
585 585
586 586 if error:
587 587 raise error
588 588
589 589 #--------------------------------------------------------------------------
590 590 # Execution methods
591 591 #--------------------------------------------------------------------------
592 592
593 593 @defaultblock
594 594 def execute(self, code, targets='all', block=None):
595 595 """Executes `code` on `targets` in blocking or nonblocking manner.
596 596
597 597 ``execute`` is always `bound` (affects engine namespace)
598 598
599 599 Parameters
600 600 ----------
601 601 code : str
602 602 the code string to be executed
603 603 targets : int/str/list of ints/strs
604 604 the engines on which to execute
605 605 default : all
606 606 block : bool
607 607 whether or not to wait until done to return
608 608 default: self.block
609 609 """
610 610 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
611 611 return result
612 612
613 613 def run(self, code, block=None):
614 614 """Runs `code` on an engine.
615 615
616 616 Calls to this are load-balanced.
617 617
618 618 ``run`` is never `bound` (no effect on engine namespace)
619 619
620 620 Parameters
621 621 ----------
622 622 code : str
623 623 the code string to be executed
624 624 block : bool
625 625 whether or not to wait until done
626 626
627 627 """
628 628 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
629 629 return result
630 630
631 631 def _maybe_raise(self, result):
632 632 """wrapper for maybe raising an exception if apply failed."""
633 633 if isinstance(result, error.RemoteError):
634 634 raise result
635 635
636 636 return result
637 637
638 638 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
639 639 after=None, follow=None):
640 640 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
641 641
642 642 This is the central execution command for the client.
643 643
644 644 Parameters
645 645 ----------
646 646
647 647 f : function
648 648 The fuction to be called remotely
649 649 args : tuple/list
650 650 The positional arguments passed to `f`
651 651 kwargs : dict
652 652 The keyword arguments passed to `f`
653 653 bound : bool (default: True)
654 654 Whether to execute in the Engine(s) namespace, or in a clean
655 655 namespace not affecting the engine.
656 656 block : bool (default: self.block)
657 657 Whether to wait for the result, or return immediately.
658 658 False:
659 659 returns msg_id(s)
660 660 if multiple targets:
661 661 list of ids
662 662 True:
663 663 returns actual result(s) of f(*args, **kwargs)
664 664 if multiple targets:
665 665 dict of results, by engine ID
666 666 targets : int,list of ints, 'all', None
667 667 Specify the destination of the job.
668 668 if None:
669 669 Submit via Task queue for load-balancing.
670 670 if 'all':
671 671 Run on all active engines
672 672 if list:
673 673 Run on each specified engine
674 674 if int:
675 675 Run on single engine
676 676
677 677 after : Dependency or collection of msg_ids
678 678 Only for load-balanced execution (targets=None)
679 679 Specify a list of msg_ids as a time-based dependency.
680 680 This job will only be run *after* the dependencies
681 681 have been met.
682 682
683 683 follow : Dependency or collection of msg_ids
684 684 Only for load-balanced execution (targets=None)
685 685 Specify a list of msg_ids as a location-based dependency.
686 686 This job will only be run on an engine where this dependency
687 687 is met.
688 688
689 689 Returns
690 690 -------
691 691 if block is False:
692 692 if single target:
693 693 return msg_id
694 694 else:
695 695 return list of msg_ids
696 696 ? (should this be dict like block=True) ?
697 697 else:
698 698 if single target:
699 699 return result of f(*args, **kwargs)
700 700 else:
701 701 return dict of results, keyed by engine
702 702 """
703 703
704 704 # defaults:
705 705 block = block if block is not None else self.block
706 706 args = args if args is not None else []
707 707 kwargs = kwargs if kwargs is not None else {}
708 708
709 709 # enforce types of f,args,kwrags
710 710 if not callable(f):
711 711 raise TypeError("f must be callable, not %s"%type(f))
712 712 if not isinstance(args, (tuple, list)):
713 713 raise TypeError("args must be tuple or list, not %s"%type(args))
714 714 if not isinstance(kwargs, dict):
715 715 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
716 716
717 717 options = dict(bound=bound, block=block, after=after, follow=follow)
718 718
719 719 if targets is None:
720 720 return self._apply_balanced(f, args, kwargs, **options)
721 721 else:
722 722 return self._apply_direct(f, args, kwargs, targets=targets, **options)
723 723
724 724 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
725 725 after=None, follow=None):
726 726 """The underlying method for applying functions in a load balanced
727 727 manner, via the task queue."""
728 728 if isinstance(after, Dependency):
729 729 after = after.as_dict()
730 730 elif after is None:
731 731 after = []
732 732 if isinstance(follow, Dependency):
733 733 follow = follow.as_dict()
734 734 elif follow is None:
735 735 follow = []
736 736 subheader = dict(after=after, follow=follow)
737 737
738 738 bufs = ss.pack_apply_message(f,args,kwargs)
739 739 content = dict(bound=bound)
740 740 msg = self.session.send(self._task_socket, "apply_request",
741 741 content=content, buffers=bufs, subheader=subheader)
742 742 msg_id = msg['msg_id']
743 743 self.outstanding.add(msg_id)
744 744 self.history.append(msg_id)
745 745 if block:
746 746 self.barrier(msg_id)
747 747 return self._maybe_raise(self.results[msg_id])
748 748 else:
749 return PendingResult(self, [msg_id])
749 return AsyncResult(self, [msg_id])
750 750
751 751 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
752 752 after=None, follow=None):
753 753 """Then underlying method for applying functions to specific engines
754 754 via the MUX queue."""
755 755
756 756 queues,targets = self._build_targets(targets)
757 757 bufs = ss.pack_apply_message(f,args,kwargs)
758 758 if isinstance(after, Dependency):
759 759 after = after.as_dict()
760 760 elif after is None:
761 761 after = []
762 762 if isinstance(follow, Dependency):
763 763 follow = follow.as_dict()
764 764 elif follow is None:
765 765 follow = []
766 766 subheader = dict(after=after, follow=follow)
767 767 content = dict(bound=bound)
768 768 msg_ids = []
769 769 for queue in queues:
770 770 msg = self.session.send(self._mux_socket, "apply_request",
771 771 content=content, buffers=bufs,ident=queue, subheader=subheader)
772 772 msg_id = msg['msg_id']
773 773 self.outstanding.add(msg_id)
774 774 self.history.append(msg_id)
775 775 msg_ids.append(msg_id)
776 776 if block:
777 777 self.barrier(msg_ids)
778 778 else:
779 return PendingResult(self, msg_ids)
779 return AsyncResult(self, msg_ids)
780 780 if len(msg_ids) == 1:
781 781 return self._maybe_raise(self.results[msg_ids[0]])
782 782 else:
783 783 result = {}
784 784 for target,mid in zip(targets, msg_ids):
785 785 result[target] = self.results[mid]
786 786 return error.collect_exceptions(result, f.__name__)
787 787
788 #--------------------------------------------------------------------------
789 # Map and decorators
790 #--------------------------------------------------------------------------
791
788 792 def map(self, f, *sequences):
789 793 """Parallel version of builtin `map`, using all our engines."""
790 794 pf = ParallelFunction(self, f, block=self.block,
791 795 bound=True, targets='all')
792 796 return pf.map(*sequences)
793 797
798 def parallel(self, bound=True, targets='all', block=True):
799 """Decorator for making a ParallelFunction"""
800 return parallel(self, bound=bound, targets=targets, block=block)
801
802 def remote(self, bound=True, targets='all', block=True):
803 """Decorator for making a RemoteFunction"""
804 return remote(self, bound=bound, targets=targets, block=block)
805
794 806 #--------------------------------------------------------------------------
795 807 # Data movement
796 808 #--------------------------------------------------------------------------
797 809
798 810 @defaultblock
799 811 def push(self, ns, targets='all', block=None):
800 812 """Push the contents of `ns` into the namespace on `target`"""
801 813 if not isinstance(ns, dict):
802 814 raise TypeError("Must be a dict, not %s"%type(ns))
803 815 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
804 816 return result
805 817
806 818 @defaultblock
807 819 def pull(self, keys, targets='all', block=True):
808 820 """Pull objects from `target`'s namespace by `keys`"""
809 821 if isinstance(keys, str):
810 822 pass
811 823 elif isinstance(keys, (list,tuple,set)):
812 824 for key in keys:
813 825 if not isinstance(key, str):
814 826 raise TypeError
815 827 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
816 828 return result
817 829
818 830 @defaultblock
819 831 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
820 832 """
821 833 Partition a Python sequence and send the partitions to a set of engines.
822 834 """
823 835 targets = self._build_targets(targets)[-1]
824 836 mapObject = Map.dists[dist]()
825 837 nparts = len(targets)
826 838 msg_ids = []
827 839 for index, engineid in enumerate(targets):
828 840 partition = mapObject.getPartition(seq, index, nparts)
829 841 if flatten and len(partition) == 1:
830 842 mid = self.push({key: partition[0]}, targets=engineid, block=False)
831 843 else:
832 844 mid = self.push({key: partition}, targets=engineid, block=False)
833 845 msg_ids.append(mid)
834 r = PendingResult(self, msg_ids)
846 r = AsyncResult(self, msg_ids)
835 847 if block:
836 848 r.wait()
837 849 return
838 850 else:
839 851 return r
840 852
841 853 @defaultblock
842 854 def gather(self, key, dist='b', targets='all', block=True):
843 855 """
844 856 Gather a partitioned sequence on a set of engines as a single local seq.
845 857 """
846 858
847 859 targets = self._build_targets(targets)[-1]
848 860 mapObject = Map.dists[dist]()
849 861 msg_ids = []
850 862 for index, engineid in enumerate(targets):
851 863 msg_ids.append(self.pull(key, targets=engineid,block=False))
852 864
853 r = PendingMapResult(self, msg_ids, mapObject)
865 r = AsyncMapResult(self, msg_ids, mapObject)
854 866 if block:
855 867 r.wait()
856 868 return r.result
857 869 else:
858 870 return r
859 871
860 872 #--------------------------------------------------------------------------
861 873 # Query methods
862 874 #--------------------------------------------------------------------------
863 875
864 876 @spinfirst
865 877 def get_results(self, msg_ids, status_only=False):
866 878 """Returns the result of the execute or task request with `msg_ids`.
867 879
868 880 Parameters
869 881 ----------
870 882 msg_ids : list of ints or msg_ids
871 883 if int:
872 884 Passed as index to self.history for convenience.
873 885 status_only : bool (default: False)
874 886 if False:
875 887 return the actual results
876 888 """
877 889 if not isinstance(msg_ids, (list,tuple)):
878 890 msg_ids = [msg_ids]
879 891 theids = []
880 892 for msg_id in msg_ids:
881 893 if isinstance(msg_id, int):
882 894 msg_id = self.history[msg_id]
883 895 if not isinstance(msg_id, str):
884 896 raise TypeError("msg_ids must be str, not %r"%msg_id)
885 897 theids.append(msg_id)
886 898
887 899 completed = []
888 900 local_results = {}
889 901 for msg_id in list(theids):
890 902 if msg_id in self.results:
891 903 completed.append(msg_id)
892 904 local_results[msg_id] = self.results[msg_id]
893 905 theids.remove(msg_id)
894 906
895 907 if theids: # some not locally cached
896 908 content = dict(msg_ids=theids, status_only=status_only)
897 909 msg = self.session.send(self._query_socket, "result_request", content=content)
898 910 zmq.select([self._query_socket], [], [])
899 911 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
900 912 if self.debug:
901 913 pprint(msg)
902 914 content = msg['content']
903 915 if content['status'] != 'ok':
904 916 raise ss.unwrap_exception(content)
905 917 else:
906 918 content = dict(completed=[],pending=[])
907 919 if not status_only:
908 920 # load cached results into result:
909 921 content['completed'].extend(completed)
910 922 content.update(local_results)
911 923 # update cache with results:
912 924 for msg_id in msg_ids:
913 925 if msg_id in content['completed']:
914 926 self.results[msg_id] = content[msg_id]
915 927 return content
916 928
917 929 @spinfirst
918 930 def queue_status(self, targets=None, verbose=False):
919 931 """Fetch the status of engine queues.
920 932
921 933 Parameters
922 934 ----------
923 935 targets : int/str/list of ints/strs
924 936 the engines on which to execute
925 937 default : all
926 938 verbose : bool
927 939 Whether to return lengths only, or lists of ids for each element
928 940 """
929 941 targets = self._build_targets(targets)[1]
930 942 content = dict(targets=targets, verbose=verbose)
931 943 self.session.send(self._query_socket, "queue_request", content=content)
932 944 idents,msg = self.session.recv(self._query_socket, 0)
933 945 if self.debug:
934 946 pprint(msg)
935 947 content = msg['content']
936 948 status = content.pop('status')
937 949 if status != 'ok':
938 950 raise ss.unwrap_exception(content)
939 951 return content
940 952
941 953 @spinfirst
942 954 def purge_results(self, msg_ids=[], targets=[]):
943 955 """Tell the controller to forget results.
944 956
945 957 Individual results can be purged by msg_id, or the entire
946 958 history of specific targets can be purged.
947 959
948 960 Parameters
949 961 ----------
950 962 msg_ids : str or list of strs
951 963 the msg_ids whose results should be forgotten.
952 964 targets : int/str/list of ints/strs
953 965 The targets, by uuid or int_id, whose entire history is to be purged.
954 966 Use `targets='all'` to scrub everything from the controller's memory.
955 967
956 968 default : None
957 969 """
958 970 if not targets and not msg_ids:
959 971 raise ValueError
960 972 if targets:
961 973 targets = self._build_targets(targets)[1]
962 974 content = dict(targets=targets, msg_ids=msg_ids)
963 975 self.session.send(self._query_socket, "purge_request", content=content)
964 976 idents, msg = self.session.recv(self._query_socket, 0)
965 977 if self.debug:
966 978 pprint(msg)
967 979 content = msg['content']
968 980 if content['status'] != 'ok':
969 981 raise ss.unwrap_exception(content)
970 982
971 983 class AsynClient(Client):
972 984 """An Asynchronous client, using the Tornado Event Loop.
973 985 !!!unfinished!!!"""
974 986 io_loop = None
975 987 _queue_stream = None
976 988 _notifier_stream = None
977 989 _task_stream = None
978 990 _control_stream = None
979 991
980 992 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
981 993 Client.__init__(self, addr, context, username, debug)
982 994 if io_loop is None:
983 995 io_loop = ioloop.IOLoop.instance()
984 996 self.io_loop = io_loop
985 997
986 998 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
987 999 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
988 1000 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
989 1001 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
990 1002
991 1003 def spin(self):
992 1004 for stream in (self.queue_stream, self.notifier_stream,
993 1005 self.task_stream, self.control_stream):
994 1006 stream.flush()
995 1007
996 1008 __all__ = [ 'Client',
997 1009 'depend',
998 1010 'require',
999 1011 'remote',
1000 1012 'parallel',
1001 1013 'RemoteFunction',
1002 1014 'ParallelFunction',
1003 1015 'DirectView',
1004 1016 'LoadBalancedView',
1005 'PendingResult',
1006 'PendingMapResult'
1017 'AsyncResult',
1018 'AsyncMapResult'
1007 1019 ]
@@ -1,280 +1,283
1 1 # encoding: utf-8
2 2
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Error classes
20 20 #-------------------------------------------------------------------------------
21 21 class IPythonError(Exception):
22 22 """Base exception that all of our exceptions inherit from.
23 23
24 24 This can be raised by code that doesn't have any more specific
25 25 information."""
26 26
27 27 pass
28 28
29 29 # Exceptions associated with the controller objects
30 30 class ControllerError(IPythonError): pass
31 31
32 32 class ControllerCreationError(ControllerError): pass
33 33
34 34
35 35 # Exceptions associated with the Engines
36 36 class EngineError(IPythonError): pass
37 37
38 38 class EngineCreationError(EngineError): pass
39 39
40 40 class KernelError(IPythonError):
41 41 pass
42 42
43 43 class NotDefined(KernelError):
44 44 def __init__(self, name):
45 45 self.name = name
46 46 self.args = (name,)
47 47
48 48 def __repr__(self):
49 49 return '<NotDefined: %s>' % self.name
50 50
51 51 __str__ = __repr__
52 52
53 53
54 54 class QueueCleared(KernelError):
55 55 pass
56 56
57 57
58 58 class IdInUse(KernelError):
59 59 pass
60 60
61 61
62 62 class ProtocolError(KernelError):
63 63 pass
64 64
65 65
66 66 class ConnectionError(KernelError):
67 67 pass
68 68
69 69
70 70 class InvalidEngineID(KernelError):
71 71 pass
72 72
73 73
74 74 class NoEnginesRegistered(KernelError):
75 75 pass
76 76
77 77
78 78 class InvalidClientID(KernelError):
79 79 pass
80 80
81 81
82 82 class InvalidDeferredID(KernelError):
83 83 pass
84 84
85 85
86 86 class SerializationError(KernelError):
87 87 pass
88 88
89 89
90 90 class MessageSizeError(KernelError):
91 91 pass
92 92
93 93
94 94 class PBMessageSizeError(MessageSizeError):
95 95 pass
96 96
97 97
98 98 class ResultNotCompleted(KernelError):
99 99 pass
100 100
101 101
102 102 class ResultAlreadyRetrieved(KernelError):
103 103 pass
104 104
105 105 class ClientError(KernelError):
106 106 pass
107 107
108 108
109 109 class TaskAborted(KernelError):
110 110 pass
111 111
112 112
113 113 class TaskTimeout(KernelError):
114 114 pass
115 115
116 116
117 117 class NotAPendingResult(KernelError):
118 118 pass
119 119
120 120
121 121 class UnpickleableException(KernelError):
122 122 pass
123 123
124 124
125 125 class AbortedPendingDeferredError(KernelError):
126 126 pass
127 127
128 128
129 129 class InvalidProperty(KernelError):
130 130 pass
131 131
132 132
133 133 class MissingBlockArgument(KernelError):
134 134 pass
135 135
136 136
137 137 class StopLocalExecution(KernelError):
138 138 pass
139 139
140 140
141 141 class SecurityError(KernelError):
142 142 pass
143 143
144 144
145 145 class FileTimeoutError(KernelError):
146 146 pass
147 147
148 class TimeoutError(KernelError):
149 pass
150
148 151 class RemoteError(KernelError):
149 152 """Error raised elsewhere"""
150 153 ename=None
151 154 evalue=None
152 155 traceback=None
153 156 engine_info=None
154 157
155 158 def __init__(self, ename, evalue, traceback, engine_info=None):
156 159 self.ename=ename
157 160 self.evalue=evalue
158 161 self.traceback=traceback
159 162 self.engine_info=engine_info or {}
160 163 self.args=(ename, evalue)
161 164
162 165 def __repr__(self):
163 166 engineid = self.engine_info.get('engineid', ' ')
164 167 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
165 168
166 169 def __str__(self):
167 170 sig = "%s(%s)"%(self.ename, self.evalue)
168 171 if self.traceback:
169 172 return sig + '\n' + self.traceback
170 173 else:
171 174 return sig
172 175
173 176
174 177 class TaskRejectError(KernelError):
175 178 """Exception to raise when a task should be rejected by an engine.
176 179
177 180 This exception can be used to allow a task running on an engine to test
178 181 if the engine (or the user's namespace on the engine) has the needed
179 182 task dependencies. If not, the task should raise this exception. For
180 183 the task to be retried on another engine, the task should be created
181 184 with the `retries` argument > 1.
182 185
183 186 The advantage of this approach over our older properties system is that
184 187 tasks have full access to the user's namespace on the engines and the
185 188 properties don't have to be managed or tested by the controller.
186 189 """
187 190
188 191
189 192 class CompositeError(KernelError):
190 193 """Error for representing possibly multiple errors on engines"""
191 194 def __init__(self, message, elist):
192 195 Exception.__init__(self, *(message, elist))
193 196 # Don't use pack_exception because it will conflict with the .message
194 197 # attribute that is being deprecated in 2.6 and beyond.
195 198 self.msg = message
196 199 self.elist = elist
197 200 self.args = [ e[0] for e in elist ]
198 201
199 202 def _get_engine_str(self, ei):
200 203 if not ei:
201 204 return '[Engine Exception]'
202 205 else:
203 206 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
204 207
205 208 def _get_traceback(self, ev):
206 209 try:
207 210 tb = ev._ipython_traceback_text
208 211 except AttributeError:
209 212 return 'No traceback available'
210 213 else:
211 214 return tb
212 215
213 216 def __str__(self):
214 217 s = str(self.msg)
215 218 for en, ev, etb, ei in self.elist:
216 219 engine_str = self._get_engine_str(ei)
217 220 s = s + '\n' + engine_str + en + ': ' + str(ev)
218 221 return s
219 222
220 223 def __repr__(self):
221 224 return "CompositeError(%i)"%len(self.elist)
222 225
223 226 def print_tracebacks(self, excid=None):
224 227 if excid is None:
225 228 for (en,ev,etb,ei) in self.elist:
226 229 print (self._get_engine_str(ei))
227 230 print (etb or 'No traceback available')
228 231 print ()
229 232 else:
230 233 try:
231 234 en,ev,etb,ei = self.elist[excid]
232 235 except:
233 236 raise IndexError("an exception with index %i does not exist"%excid)
234 237 else:
235 238 print (self._get_engine_str(ei))
236 239 print (etb or 'No traceback available')
237 240
238 241 def raise_exception(self, excid=0):
239 242 try:
240 243 en,ev,etb,ei = self.elist[excid]
241 244 except:
242 245 raise IndexError("an exception with index %i does not exist"%excid)
243 246 else:
244 247 try:
245 248 raise RemoteError(en, ev, etb, ei)
246 249 except:
247 250 et,ev,tb = sys.exc_info()
248 251
249 252
250 253 def collect_exceptions(rdict_or_list, method):
251 254 """check a result dict for errors, and raise CompositeError if any exist.
252 255 Passthrough otherwise."""
253 256 elist = []
254 257 if isinstance(rdict_or_list, dict):
255 258 rlist = rdict_or_list.values()
256 259 else:
257 260 rlist = rdict_or_list
258 261 for r in rlist:
259 262 if isinstance(r, RemoteError):
260 263 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
261 264 # Sometimes we could have CompositeError in our list. Just take
262 265 # the errors out of them and put them in our new list. This
263 266 # has the effect of flattening lists of CompositeErrors into one
264 267 # CompositeError
265 268 if en=='CompositeError':
266 269 for e in ev.elist:
267 270 elist.append(e)
268 271 else:
269 272 elist.append((en, ev, etb, ei))
270 273 if len(elist)==0:
271 274 return rdict_or_list
272 275 else:
273 276 msg = "one or more exceptions from call to method: %s" % (method)
274 277 # This silliness is needed so the debugger has access to the exception
275 278 # instance (e in this case)
276 279 try:
277 280 raise CompositeError(msg, elist)
278 281 except CompositeError, e:
279 282 raise e
280 283
@@ -1,171 +1,171
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5 """
6 6
7 7 from __future__ import print_function
8 8 import time
9 9 import uuid
10 10
11 11 import zmq
12 from zmq.devices import ProcessDevice
12 from zmq.devices import ProcessDevice,ThreadDevice
13 13 from zmq.eventloop import ioloop, zmqstream
14 14
15 15 #internal
16 16 from IPython.zmq.log import logger
17 17
18 18 class Heart(object):
19 19 """A basic heart object for responding to a HeartMonitor.
20 20 This is a simple wrapper with defaults for the most common
21 21 Device model for responding to heartbeats.
22 22
23 23 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
24 24 SUB/XREQ for in/out.
25 25
26 26 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
27 27 device=None
28 28 id=None
29 29 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
30 self.device = ProcessDevice(zmq.FORWARDER, in_type, out_type)
30 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
31 31 self.device.daemon=True
32 32 self.device.connect_in(in_addr)
33 33 self.device.connect_out(out_addr)
34 34 if in_type == zmq.SUB:
35 35 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
36 36 if heart_id is None:
37 37 heart_id = str(uuid.uuid4())
38 38 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
39 39 self.id = heart_id
40 40
41 41 def start(self):
42 42 return self.device.start()
43 43
44 44 class HeartMonitor(object):
45 45 """A basic HeartMonitor class
46 46 pingstream: a PUB stream
47 47 pongstream: an XREP stream
48 48 period: the period of the heartbeat in milliseconds"""
49 49 loop=None
50 50 pingstream=None
51 51 pongstream=None
52 52 period=None
53 53 hearts=None
54 54 on_probation=None
55 55 last_ping=None
56 56
57 57 def __init__(self, loop, pingstream, pongstream, period=1000):
58 58 self.loop = loop
59 59 self.period = period
60 60
61 61 self.pingstream = pingstream
62 62 self.pongstream = pongstream
63 63 self.pongstream.on_recv(self.handle_pong)
64 64
65 65 self.hearts = set()
66 66 self.responses = set()
67 67 self.on_probation = set()
68 68 self.lifetime = 0
69 69 self.tic = time.time()
70 70
71 71 self._new_handlers = set()
72 72 self._failure_handlers = set()
73 73
74 74 def start(self):
75 75 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
76 76 self.caller.start()
77 77
78 78 def add_new_heart_handler(self, handler):
79 79 """add a new handler for new hearts"""
80 80 logger.debug("heartbeat::new_heart_handler: %s"%handler)
81 81 self._new_handlers.add(handler)
82 82
83 83 def add_heart_failure_handler(self, handler):
84 84 """add a new handler for heart failure"""
85 85 logger.debug("heartbeat::new heart failure handler: %s"%handler)
86 86 self._failure_handlers.add(handler)
87 87
88 88 # def _flush(self):
89 89 # """override IOLoop triggers"""
90 90 # while True:
91 91 # try:
92 92 # msg = self.pongstream.socket.recv_multipart(zmq.NOBLOCK)
93 93 # logger.warn("IOLoop triggered beat with incoming heartbeat waiting to be handled")
94 94 # except zmq.ZMQError:
95 95 # return
96 96 # else:
97 97 # self.handle_pong(msg)
98 98 # # print '.'
99 99 #
100 100
101 101 def beat(self):
102 102 self.pongstream.flush()
103 103 self.last_ping = self.lifetime
104 104
105 105 toc = time.time()
106 106 self.lifetime += toc-self.tic
107 107 self.tic = toc
108 108 logger.debug("heartbeat::%s"%self.lifetime)
109 109 goodhearts = self.hearts.intersection(self.responses)
110 110 missed_beats = self.hearts.difference(goodhearts)
111 111 heartfailures = self.on_probation.intersection(missed_beats)
112 112 newhearts = self.responses.difference(goodhearts)
113 113 map(self.handle_new_heart, newhearts)
114 114 map(self.handle_heart_failure, heartfailures)
115 115 self.on_probation = missed_beats.intersection(self.hearts)
116 116 self.responses = set()
117 117 # print self.on_probation, self.hearts
118 118 # logger.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
119 119 self.pingstream.send(str(self.lifetime))
120 120
121 121 def handle_new_heart(self, heart):
122 122 if self._new_handlers:
123 123 for handler in self._new_handlers:
124 124 handler(heart)
125 125 else:
126 126 logger.info("heartbeat::yay, got new heart %s!"%heart)
127 127 self.hearts.add(heart)
128 128
129 129 def handle_heart_failure(self, heart):
130 130 if self._failure_handlers:
131 131 for handler in self._failure_handlers:
132 132 try:
133 133 handler(heart)
134 134 except Exception as e:
135 135 print (e)
136 136 logger.error("heartbeat::Bad Handler! %s"%handler)
137 137 pass
138 138 else:
139 139 logger.info("heartbeat::Heart %s failed :("%heart)
140 140 self.hearts.remove(heart)
141 141
142 142
143 143 def handle_pong(self, msg):
144 144 "a heart just beat"
145 145 if msg[1] == str(self.lifetime):
146 146 delta = time.time()-self.tic
147 147 logger.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
148 148 self.responses.add(msg[0])
149 149 elif msg[1] == str(self.last_ping):
150 150 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
151 151 logger.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
152 152 self.responses.add(msg[0])
153 153 else:
154 154 logger.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
155 155 (msg[1],self.lifetime))
156 156
157 157
158 158 if __name__ == '__main__':
159 159 loop = ioloop.IOLoop.instance()
160 160 context = zmq.Context()
161 161 pub = context.socket(zmq.PUB)
162 162 pub.bind('tcp://127.0.0.1:5555')
163 163 xrep = context.socket(zmq.XREP)
164 164 xrep.bind('tcp://127.0.0.1:5556')
165 165
166 166 outstream = zmqstream.ZMQStream(pub, loop)
167 167 instream = zmqstream.ZMQStream(xrep, loop)
168 168
169 169 hb = HeartMonitor(loop, outstream, instream)
170 170
171 171 loop.start()
@@ -1,145 +1,145
1 1 """Remote Functions and decorators for the client."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import map as Map
14 from pendingresult import PendingMapResult
14 from asyncresult import AsyncMapResult
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Decorators
18 18 #-----------------------------------------------------------------------------
19 19
20 20 def remote(client, bound=False, block=None, targets=None):
21 21 """Turn a function into a remote function.
22 22
23 23 This method can be used for map:
24 24
25 25 >>> @remote(client,block=True)
26 26 def func(a)
27 27 """
28 28 def remote_function(f):
29 29 return RemoteFunction(client, f, bound, block, targets)
30 30 return remote_function
31 31
32 32 def parallel(client, dist='b', bound=False, block=None, targets='all'):
33 33 """Turn a function into a parallel remote function.
34 34
35 35 This method can be used for map:
36 36
37 37 >>> @parallel(client,block=True)
38 38 def func(a)
39 39 """
40 40 def parallel_function(f):
41 41 return ParallelFunction(client, f, dist, bound, block, targets)
42 42 return parallel_function
43 43
44 44 #--------------------------------------------------------------------------
45 45 # Classes
46 46 #--------------------------------------------------------------------------
47 47
48 48 class RemoteFunction(object):
49 49 """Turn an existing function into a remote function.
50 50
51 51 Parameters
52 52 ----------
53 53
54 54 client : Client instance
55 55 The client to be used to connect to engines
56 56 f : callable
57 57 The function to be wrapped into a remote function
58 58 bound : bool [default: False]
59 59 Whether the affect the remote namespace when called
60 60 block : bool [default: None]
61 61 Whether to wait for results or not. The default behavior is
62 62 to use the current `block` attribute of `client`
63 63 targets : valid target list [default: all]
64 64 The targets on which to execute.
65 65 """
66 66
67 67 client = None # the remote connection
68 68 func = None # the wrapped function
69 69 block = None # whether to block
70 70 bound = None # whether to affect the namespace
71 71 targets = None # where to execute
72 72
73 73 def __init__(self, client, f, bound=False, block=None, targets=None):
74 74 self.client = client
75 75 self.func = f
76 76 self.block=block
77 77 self.bound=bound
78 78 self.targets=targets
79 79
80 80 def __call__(self, *args, **kwargs):
81 81 return self.client.apply(self.func, args=args, kwargs=kwargs,
82 82 block=self.block, targets=self.targets, bound=self.bound)
83 83
84 84
85 85 class ParallelFunction(RemoteFunction):
86 86 """Class for mapping a function to sequences."""
87 87 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'):
88 88 super(ParallelFunction, self).__init__(client,f,bound,block,targets)
89 89 mapClass = Map.dists[dist]
90 90 self.mapObject = mapClass()
91 91
92 92 def __call__(self, *sequences):
93 93 len_0 = len(sequences[0])
94 94 for s in sequences:
95 95 if len(s)!=len_0:
96 96 raise ValueError('all sequences must have equal length')
97 97
98 98 if self.targets is None:
99 99 # load-balanced:
100 100 engines = [None]*len_0
101 101 elif isinstance(self.targets, int):
102 102 engines = [None]*self.targets
103 103 else:
104 104 # multiplexed:
105 105 engines = self.client._build_targets(self.targets)[-1]
106 106
107 107 nparts = len(engines)
108 108 msg_ids = []
109 109 # my_f = lambda *a: map(self.func, *a)
110 110 for index, engineid in enumerate(engines):
111 111 args = []
112 112 for seq in sequences:
113 113 part = self.mapObject.getPartition(seq, index, nparts)
114 114 if not part:
115 115 continue
116 116 else:
117 117 args.append(part)
118 118 if not args:
119 119 continue
120 120
121 121 # print (args)
122 122 if hasattr(self, '_map'):
123 123 f = map
124 124 args = [self.func]+args
125 125 else:
126 126 f=self.func
127 127 mid = self.client.apply(f, args=args, block=False,
128 128 bound=self.bound,
129 targets=engineid).msg_ids[0]
129 targets=engineid)._msg_ids[0]
130 130 msg_ids.append(mid)
131 131
132 r = PendingMapResult(self.client, msg_ids, self.mapObject)
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject)
133 133 if self.block:
134 134 r.wait()
135 135 return r.result
136 136 else:
137 137 return r
138 138
139 139 def map(self, *sequences):
140 140 """call a function on each element of a sequence remotely."""
141 141 self._map = True
142 142 ret = self.__call__(*sequences)
143 143 del self._map
144 144 return ret
145 145
1 NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now