From b0d94c76944aa7bce95b820f8c78028bfbf4e8f7 2011-04-08 00:38:12
From: MinRK
Date: 2011-04-08 00:38:12
Subject: [PATCH] adapt kernel/error.py to zmq, improve error propagation.

---
diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py
index 7353d97..6dce41c 100644
--- a/IPython/zmq/parallel/client.py
+++ b/IPython/zmq/parallel/client.py
@@ -27,6 +27,7 @@ import streamsession as ss
 # from remotenamespace import RemoteNamespace
 from view import DirectView, LoadBalancedView
 from dependency import Dependency, depend, require
+import error
 
 def _push(ns):
     globals().update(ns)
@@ -128,13 +129,14 @@ class AbortedTask(object):
     def __init__(self, msg_id):
         self.msg_id = msg_id
 
-class ControllerError(Exception):
-    """Exception Class for errors in the controller (not the Engine)."""
-    def __init__(self, etype, evalue, tb):
-        self.etype = etype
-        self.evalue = evalue
-        self.traceback=tb
-    
+class ResultDict(dict):
+    """A subclass of dict that raises errors if it has them."""
+    def __getitem__(self, key):
+        res = dict.__getitem__(self, key)
+        if isinstance(res, error.KernelError):
+            raise res
+        return res
+
 
 class Client(object):
     """A semi-synchronous client to the IPython ZMQ controller
@@ -407,7 +409,15 @@ class Client(object):
             # TODO: handle resubmission
             pass
         else:
-            self.results[msg_id] = ss.unwrap_exception(content)
+            e = ss.unwrap_exception(content)
+            # map the engine's uuid back to its integer id; errors raised
+            # outside an engine carry no 'engineid', so don't index blindly
+            e_uuid = e.engine_info.get('engineid')
+            for k,v in self._engines.iteritems():
+                if v == e_uuid:
+                    e.engine_info['engineid'] = k
+                    break
+            self.results[msg_id] = e
 
     def _flush_notifications(self):
         """Flush notifications of engine registrations waiting
@@ -649,6 +659,13 @@ class Client(object):
         result = self.apply(execute, (code,), targets=None,
                 block=block, bound=False)
         return result
 
+    def _maybe_raise(self, result):
+        """wrapper for maybe raising an exception if apply failed."""
+        if isinstance(result, error.RemoteError):
+            raise result
+
+        return result
+
     def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None, after=None, follow=None):
         """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
@@ -758,7 +775,7 @@ class Client(object):
         self.history.append(msg_id)
         if block:
             self.barrier(msg_id)
-            return self.results[msg_id]
+            return self._maybe_raise(self.results[msg_id])
         else:
             return msg_id
 
@@ -795,12 +812,12 @@ class Client(object):
         else:
             return msg_ids
         if len(msg_ids) == 1:
-            return self.results[msg_ids[0]]
+            return self._maybe_raise(self.results[msg_ids[0]])
         else:
             result = {}
             for target,mid in zip(targets, msg_ids):
                 result[target] = self.results[mid]
-            return result
+            return error.collect_exceptions(result, getattr(f, '__name__', repr(f)))
 
     #--------------------------------------------------------------------------
     # Data movement
diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py
new file mode 100644
index 0000000..9f1a735
--- /dev/null
+++ b/IPython/zmq/parallel/error.py
@@ -0,0 +1,271 @@
+# encoding: utf-8
+
+"""Classes and functions for kernel related errors and exceptions."""
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+# Tell nose to skip this module
+__test__ = {}
+
+#-------------------------------------------------------------------------------
+#  Copyright (C) 2008  The IPython Development Team
+#
+#  Distributed under the terms of the BSD License.  The full license is in
+#  the file COPYING, distributed as part of this software.
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Error classes
+#-------------------------------------------------------------------------------
+class IPythonError(Exception):
+    """Base exception that all of our exceptions inherit from.
+
+    This can be raised by code that doesn't have any more specific
+    information."""
+
+    pass
+
+# Exceptions associated with the controller objects
+class ControllerError(IPythonError): pass
+
+class ControllerCreationError(ControllerError): pass
+
+
+# Exceptions associated with the Engines
+class EngineError(IPythonError): pass
+
+class EngineCreationError(EngineError): pass
+
+class KernelError(IPythonError):
+    pass
+
+class NotDefined(KernelError):
+    def __init__(self, name):
+        self.name = name
+        self.args = (name,)
+
+    def __repr__(self):
+        return '<NotDefined: %s>' % self.name
+
+    __str__ = __repr__
+
+
+class QueueCleared(KernelError):
+    pass
+
+
+class IdInUse(KernelError):
+    pass
+
+
+class ProtocolError(KernelError):
+    pass
+
+
+class ConnectionError(KernelError):
+    pass
+
+
+class InvalidEngineID(KernelError):
+    pass
+
+
+class NoEnginesRegistered(KernelError):
+    pass
+
+
+class InvalidClientID(KernelError):
+    pass
+
+
+class InvalidDeferredID(KernelError):
+    pass
+
+
+class SerializationError(KernelError):
+    pass
+
+
+class MessageSizeError(KernelError):
+    pass
+
+
+class PBMessageSizeError(MessageSizeError):
+    pass
+
+
+class ResultNotCompleted(KernelError):
+    pass
+
+
+class ResultAlreadyRetrieved(KernelError):
+    pass
+
+class ClientError(KernelError):
+    pass
+
+
+class TaskAborted(KernelError):
+    pass
+
+
+class TaskTimeout(KernelError):
+    pass
+
+
+class NotAPendingResult(KernelError):
+    pass
+
+
+class UnpickleableException(KernelError):
+    pass
+
+
+class AbortedPendingDeferredError(KernelError):
+    pass
+
+
+class InvalidProperty(KernelError):
+    pass
+
+
+class MissingBlockArgument(KernelError):
+    pass
+
+
+class StopLocalExecution(KernelError):
+    pass
+
+
+class SecurityError(KernelError):
+    pass
+
+
+class FileTimeoutError(KernelError):
+    pass
+
+class RemoteError(KernelError):
+    """Error raised elsewhere"""
+    ename=None
+    evalue=None
+    traceback=None
+    engine_info=None
+
+    def __init__(self, ename, evalue, traceback, engine_info=None):
+        self.ename=ename
+        self.evalue=evalue
+        self.traceback=traceback
+        self.engine_info=engine_info or {}
+        self.args=(ename, evalue)
+
+    def __repr__(self):
+        engineid = self.engine_info.get('engineid', ' ')
+        return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
+
+    def __str__(self):
+        sig = "%s(%s)"%(self.ename, self.evalue)
+        if self.traceback:
+            return sig + '\n' + self.traceback
+        else:
+            return sig
+
+
+class TaskRejectError(KernelError):
+    """Exception to raise when a task should be rejected by an engine.
+
+    This exception can be used to allow a task running on an engine to test
+    if the engine (or the user's namespace on the engine) has the needed
+    task dependencies.  If not, the task should raise this exception.  For
+    the task to be retried on another engine, the task should be created
+    with the `retries` argument > 1.
+
+    The advantage of this approach over our older properties system is that
+    tasks have full access to the user's namespace on the engines and the
+    properties don't have to be managed or tested by the controller.
+    """
+
+
+class CompositeError(KernelError):
+    """Error for representing possibly multiple errors on engines"""
+    def __init__(self, message, elist):
+        Exception.__init__(self, *(message, elist))
+        # Don't use pack_exception because it will conflict with the .message
+        # attribute that is being deprecated in 2.6 and beyond.
+        self.msg = message
+        self.elist = elist
+        self.args = [ e[0] for e in elist ]
+
+    def _get_engine_str(self, ei):
+        if not ei:
+            return '[Engine Exception]'
+        else:
+            # engineid is an int once the client has mapped it, but stays the
+            # raw engine identity (a string) otherwise, so format it with %s.
+            return '[%s:%s]: ' % (ei.get('engineid'), ei.get('method'))
+
+    def _get_traceback(self, ev):
+        try:
+            tb = ev._ipython_traceback_text
+        except AttributeError:
+            return 'No traceback available'
+        else:
+            return tb
+
+    def __str__(self):
+        s = str(self.msg)
+        for en, ev, etb, ei in self.elist:
+            engine_str = self._get_engine_str(ei)
+            s = s + '\n' + engine_str + en + ': ' + str(ev)
+        return s
+
+    def __repr__(self):
+        return "CompositeError(%i)"%len(self.elist)
+
+    def print_tracebacks(self, excid=None):
+        if excid is None:
+            for (en,ev,etb,ei) in self.elist:
+                print (self._get_engine_str(ei))
+                print (etb or 'No traceback available')
+                print ()
+        else:
+            try:
+                en,ev,etb,ei = self.elist[excid]
+            except IndexError:
+                raise IndexError("an exception with index %i does not exist"%excid)
+            else:
+                print (self._get_engine_str(ei))
+                print (etb or 'No traceback available')
+
+    def raise_exception(self, excid=0):
+        """Raise the `excid`-th collected error as a RemoteError."""
+        try:
+            en,ev,etb,ei = self.elist[excid]
+        except IndexError:
+            raise IndexError("an exception with index %i does not exist"%excid)
+        else:
+            raise RemoteError(en, ev, etb, ei)
+
+
+def collect_exceptions(rdict, method):
+    """check a result dict for errors, and raise CompositeError if any exist.
+    Passthrough otherwise."""
+    elist = []
+    for r in rdict.values():
+        if isinstance(r, RemoteError):
+            en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
+            # Sometimes we could have CompositeError in our list.  Just take
+            # the errors out of them and put them in our new list.  This
+            # has the effect of flattening lists of CompositeErrors into one
+            # CompositeError.  An evalue that went through wrap_exception is
+            # only a string, so flatten only when it still carries an elist.
+            sub_elist = getattr(ev, 'elist', None)
+            if en=='CompositeError' and sub_elist is not None:
+                elist.extend(sub_elist)
+            else:
+                elist.append((en, ev, etb, ei))
+    if len(elist)==0:
+        return rdict
+    else:
+        msg = "one or more exceptions from call to method: %s" % (method)
+        raise CompositeError(msg, elist)
diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py
index e0520ba..8a75254 100755
--- a/IPython/zmq/parallel/streamkernel.py
+++ b/IPython/zmq/parallel/streamkernel.py
@@ -24,6 +24,7 @@ import zmq
 from zmq.eventloop import ioloop, zmqstream
 
 # Local imports.
+from IPython.core import ultratb
 from IPython.utils.traitlets import HasTraits, Instance, List
 from IPython.zmq.completer import KernelCompleter
 from IPython.zmq.log import logger # a Logger object
@@ -73,7 +74,13 @@ class Kernel(HasTraits):
         for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
             self.control_handlers[msg_type] = getattr(self, msg_type)
 
-    
+
+
+    def _wrap_exception(self, method=None):
+        e_info = dict(engineid=self.identity, method=method)
+        content=wrap_exception(e_info)
+        return content
+
     #-------------------- control handlers -----------------------------
     def abort_queues(self):
         for stream in self.shell_streams:
@@ -131,7 +138,7 @@ class Kernel(HasTraits):
         try:
             self.abort_queues()
         except:
-            content = wrap_exception()
+            content = self._wrap_exception('shutdown')
         else:
             content = dict(parent['content'])
             content['status'] = 'ok'
@@ -214,7 +221,7 @@ class Kernel(HasTraits):
             sys.displayhook.set_parent(parent)
             exec comp_code in self.user_ns, self.user_ns
         except:
-            exc_content = wrap_exception()
+            exc_content = self._wrap_exception('execute')
             # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
             self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
             reply_content = exc_content
@@ -291,13 +298,13 @@ class Kernel(HasTraits):
             packed_result,buf = serialize_object(result)
             result_buf = [packed_result]+buf
         except:
-            exc_content = wrap_exception()
+            exc_content = self._wrap_exception('apply')
             # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
             self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent)
             reply_content = exc_content
             result_buf = []
 
-            if etype is UnmetDependency:
+            if exc_content['ename'] == UnmetDependency.__name__:
                 sub['dependencies_met'] = False
         else:
             reply_content = {'status' : 'ok'}
diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py
index 5baa468..6f2229a 100644
--- a/IPython/zmq/parallel/streamsession.py
+++ b/IPython/zmq/parallel/streamsession.py
@@ -17,6 +17,8 @@ from zmq.eventloop.zmqstream import ZMQStream
 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
 from IPython.utils.newserialized import serialize, unserialize
 
+from IPython.zmq.parallel.error import RemoteError
+
 try:
     import cPickle
     pickle = cPickle
@@ -60,25 +62,22 @@ else:
 DELIM="<IDS|MSG>"
 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
 
-def wrap_exception():
+def wrap_exception(engine_info=None):
     etype, evalue, tb = sys.exc_info()
-    tb = traceback.format_exception(etype, evalue, tb)
+    stb = traceback.format_exception(etype, evalue, tb)
     exc_content = {
         'status' : 'error',
-        'traceback' : [ line.encode('utf8') for line in tb ],
-        'etype' : str(etype).encode('utf8'),
-        'evalue' : evalue.encode('utf8')
+        'traceback' : stb,
+        'ename' : unicode(etype.__name__),
+        'evalue' : unicode(evalue),
+        'engine_info' : engine_info or {}
     }
     return exc_content
 
-class KernelError(Exception):
-    pass
-
 def unwrap_exception(content):
-    err = KernelError(content['etype'], content['evalue'])
-    err.evalue = content['evalue']
-    err.etype = content['etype']
-    err.traceback = ''.join(content['traceback'])
+    err = RemoteError(content['ename'], content['evalue'],
+                ''.join(content['traceback']),
+                content.get('engine_info', {}))
     return err
 
 
@@ -402,7 +401,7 @@ class StreamSession(object):
             pprint.pprint(buffers)
         return omsg
 
-    def send_raw(self, stream, msg, flags=0, copy=True, idents=None):
+    def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
         """Send a raw message via idents.
 
         Parameters