diff --git a/IPython/core/formatters.py b/IPython/core/formatters.py index 01a8637..371ff05 100644 --- a/IPython/core/formatters.py +++ b/IPython/core/formatters.py @@ -34,7 +34,9 @@ from IPython.lib import pretty from IPython.utils.traitlets import ( Bool, Dict, Integer, Unicode, CUnicode, ObjectName, List, ) -from IPython.utils.py3compat import unicode_to_str, with_metaclass, PY3 +from IPython.utils.py3compat import ( + unicode_to_str, with_metaclass, PY3, string_types, +) if PY3: from io import StringIO @@ -46,7 +48,6 @@ else: # The main DisplayFormatter class #----------------------------------------------------------------------------- - class DisplayFormatter(Configurable): # When set to true only the default plain text formatter will be used. @@ -206,6 +207,22 @@ class FormatterABC(with_metaclass(abc.ABCMeta, object)): return None +def _mod_name_key(typ): + """Return a (__module__, __name__) tuple for a type. + + Used as key in Formatter.deferred_printers. + """ + module = getattr(typ, '__module__', None) + name = getattr(typ, '__name__', None) + return (module, name) + + +def _get_type(obj): + """Return the type of an instance (old and new-style)""" + return getattr(obj, '__class__', None) or type(obj) + +_raise_key_error = object() + class BaseFormatter(Configurable): """A base formatter class that is configurable. @@ -240,74 +257,142 @@ class BaseFormatter(Configurable): # The singleton printers. # Maps the IDs of the builtin singleton objects to the format functions. singleton_printers = Dict(config=True) - def _singleton_printers_default(self): - return {} # The type-specific printers. # Map type objects to the format functions. type_printers = Dict(config=True) - def _type_printers_default(self): - return {} # The deferred-import type-specific printers. # Map (modulename, classname) pairs to the format functions. deferred_printers = Dict(config=True) - def _deferred_printers_default(self): - return {} def __call__(self, obj): """Compute the format for an object.""" if self.enabled: - obj_id = id(obj) try: - obj_class = getattr(obj, '__class__', None) or type(obj) - # First try to find registered singleton printers for the type. + # lookup registered printer try: - printer = self.singleton_printers[obj_id] - except (TypeError, KeyError): + printer = self.lookup(obj) + except KeyError: pass else: return printer(obj) - # Next look for type_printers. - for cls in pretty._get_mro(obj_class): - if cls in self.type_printers: - return self.type_printers[cls](obj) - else: - printer = self._in_deferred_types(cls) - if printer is not None: - return printer(obj) - # Finally look for special method names. - if hasattr(obj_class, self.print_method): - printer = getattr(obj_class, self.print_method) - return printer(obj) + # Finally look for special method names + method = pretty._safe_getattr(obj, self.print_method, None) + if method is not None: + return method() return None except Exception: pass else: return None + + def __contains__(self, typ): + """map in to lookup_by_type""" + try: + self.lookup_by_type(typ) + except KeyError: + return False + else: + return True + + def lookup(self, obj): + """Look up the formatter for a given instance. + + Parameters + ---------- + obj : object instance - def for_type(self, typ, func): - """Add a format function for a given type. + Returns + ------- + f : callable + The registered formatting callable for the type. + + Raises + ------ + KeyError if the type has not been registered. + """ + # look for singleton first + obj_id = id(obj) + if obj_id in self.singleton_printers: + return self.singleton_printers[obj_id] + # then lookup by type + return self.lookup_by_type(_get_type(obj)) + + def lookup_by_type(self, typ): + """Look up the registered formatter for a type. + + Parameters + ---------- + typ : type or '__module__.__name__' string for a type + + Returns + ------- + f : callable + The registered formatting callable for the type. + + Raises + ------ + KeyError if the type has not been registered. + """ + if isinstance(typ, string_types): + typ_key = tuple(typ.rsplit('.',1)) + if typ_key not in self.deferred_printers: + # We may have it cached in the type map. We will have to + # iterate over all of the types to check. + for cls in self.type_printers: + if _mod_name_key(cls) == typ_key: + return self.type_printers[cls] + else: + return self.deferred_printers[typ_key] + else: + for cls in pretty._get_mro(typ): + if cls in self.type_printers or self._in_deferred_types(cls): + return self.type_printers[cls] + + # If we have reached here, the lookup failed. + raise KeyError("No registered printer for {0!r}".format(typ)) + def for_type(self, typ, func=None): + """Add a format function for a given type. + Parameters ----------- - typ : class + typ : type or '__module__.__name__' string for a type The class of the object that will be formatted using `func`. func : callable - The callable that will be called to compute the format data. The - call signature of this function is simple, it must take the - object to be formatted and return the raw data for the given - format. Subclasses may use a different call signature for the + A callable for computing the format data. + `func` will be called with the object to be formatted, + and will return the raw data in this formatter's format. + Subclasses may use a different call signature for the `func` argument. + + If `func` is None or not specified, there will be no change, + only returning the current value. + + Returns + ------- + oldfunc : callable + The currently registered callable. + If you are registering a new formatter, + this will be the previous value (to enable restoring later). """ - oldfunc = self.type_printers.get(typ, None) + # if string given, interpret as 'pkg.module.class_name' + if isinstance(typ, string_types): + type_module, type_name = typ.rsplit('.', 1) + return self.for_type_by_name(type_module, type_name, func) + + try: + oldfunc = self.lookup_by_type(typ) + except KeyError: + oldfunc = None + if func is not None: - # To support easy restoration of old printers, we need to ignore - # Nones. self.type_printers[typ] = func + return oldfunc - def for_type_by_name(self, type_module, type_name, func): + def for_type_by_name(self, type_module, type_name, func=None): """Add a format function for a type specified by the full dotted module and name of the type, rather than the type of the object. @@ -319,37 +404,89 @@ class BaseFormatter(Configurable): type_name : str The name of the type (the class name), like ``dtype`` func : callable - The callable that will be called to compute the format data. The - call signature of this function is simple, it must take the - object to be formatted and return the raw data for the given - format. Subclasses may use a different call signature for the + A callable for computing the format data. + `func` will be called with the object to be formatted, + and will return the raw data in this formatter's format. + Subclasses may use a different call signature for the `func` argument. + + If `func` is None or unspecified, there will be no change, + only returning the current value. + + Returns + ------- + oldfunc : callable + The currently registered callable. + If you are registering a new formatter, + this will be the previous value (to enable restoring later). """ key = (type_module, type_name) - oldfunc = self.deferred_printers.get(key, None) + + try: + oldfunc = self.lookup_by_type("%s.%s" % key) + except KeyError: + oldfunc = None + if func is not None: - # To support easy restoration of old printers, we need to ignore - # Nones. self.deferred_printers[key] = func return oldfunc + + def pop(self, typ, default=_raise_key_error): + """Pop a formatter for the given type. + + Parameters + ---------- + typ : type or '__module__.__name__' string for a type + default : object + value to be returned if no formatter is registered for typ. + + Returns + ------- + obj : object + The last registered object for the type. + + Raises + ------ + KeyError if the type is not registered and default is not specified. + """ + + if isinstance(typ, string_types): + typ_key = tuple(typ.rsplit('.',1)) + if typ_key not in self.deferred_printers: + # We may have it cached in the type map. We will have to + # iterate over all of the types to check. + for cls in self.type_printers: + if _mod_name_key(cls) == typ_key: + old = self.type_printers.pop(cls) + break + else: + old = default + else: + old = self.deferred_printers.pop(typ_key) + else: + if typ in self.type_printers: + old = self.type_printers.pop(typ) + else: + old = self.deferred_printers.pop(_mod_name_key(typ), default) + if old is _raise_key_error: + raise KeyError("No registered value for {0!r}".format(typ)) + return old def _in_deferred_types(self, cls): """ Check if the given class is specified in the deferred type registry. - Returns the printer from the registry if it exists, and None if the - class is not in the registry. Successful matches will be moved to the - regular type registry for future use. + Successful matches will be moved to the regular type registry for future use. """ mod = getattr(cls, '__module__', None) name = getattr(cls, '__name__', None) key = (mod, name) - printer = None if key in self.deferred_printers: # Move the printer over to the regular registry. printer = self.deferred_printers.pop(key) self.type_printers[cls] = printer - return printer + return True + return False class PlainTextFormatter(BaseFormatter): diff --git a/IPython/core/pylabtools.py b/IPython/core/pylabtools.py index 986406b..a53a56e 100644 --- a/IPython/core/pylabtools.py +++ b/IPython/core/pylabtools.py @@ -175,13 +175,13 @@ def select_figure_format(shell, fmt): png_formatter = shell.display_formatter.formatters['image/png'] if fmt == 'png': - svg_formatter.type_printers.pop(Figure, None) + svg_formatter.pop(Figure, None) png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png')) elif fmt in ('png2x', 'retina'): - svg_formatter.type_printers.pop(Figure, None) + svg_formatter.pop(Figure, None) png_formatter.for_type(Figure, retina_figure) elif fmt == 'svg': - png_formatter.type_printers.pop(Figure, None) + png_formatter.pop(Figure, None) svg_formatter.for_type(Figure, lambda fig: print_figure(fig, 'svg')) else: raise ValueError("supported formats are: 'png', 'retina', 'svg', not %r" % fmt) diff --git a/IPython/core/tests/test_formatters.py b/IPython/core/tests/test_formatters.py index 98763ad..c792cae 100644 --- a/IPython/core/tests/test_formatters.py +++ b/IPython/core/tests/test_formatters.py @@ -9,7 +9,7 @@ except: numpy = None import nose.tools as nt -from IPython.core.formatters import PlainTextFormatter +from IPython.core.formatters import PlainTextFormatter, _mod_name_key class A(object): def __repr__(self): @@ -19,6 +19,9 @@ class B(A): def __repr__(self): return 'B()' +class C: + pass + class BadPretty(object): _repr_pretty_ = None @@ -87,4 +90,145 @@ def test_bad_precision(): nt.assert_raises(ValueError, set_fp, 'foo') nt.assert_raises(ValueError, set_fp, -1) +def test_for_type(): + f = PlainTextFormatter() + + # initial return, None + nt.assert_is(f.for_type(C, foo_printer), None) + # no func queries + nt.assert_is(f.for_type(C), foo_printer) + # shouldn't change anything + nt.assert_is(f.for_type(C), foo_printer) + # None should do the same + nt.assert_is(f.for_type(C, None), foo_printer) + nt.assert_is(f.for_type(C, None), foo_printer) + +def test_for_type_string(): + f = PlainTextFormatter() + + mod = C.__module__ + + type_str = '%s.%s' % (C.__module__, 'C') + + # initial return, None + nt.assert_is(f.for_type(type_str, foo_printer), None) + # no func queries + nt.assert_is(f.for_type(type_str), foo_printer) + nt.assert_in(_mod_name_key(C), f.deferred_printers) + nt.assert_is(f.for_type(C), foo_printer) + nt.assert_not_in(_mod_name_key(C), f.deferred_printers) + nt.assert_in(C, f.type_printers) + +def test_for_type_by_name(): + f = PlainTextFormatter() + + mod = C.__module__ + + # initial return, None + nt.assert_is(f.for_type_by_name(mod, 'C', foo_printer), None) + # no func queries + nt.assert_is(f.for_type_by_name(mod, 'C'), foo_printer) + # shouldn't change anything + nt.assert_is(f.for_type_by_name(mod, 'C'), foo_printer) + # None should do the same + nt.assert_is(f.for_type_by_name(mod, 'C', None), foo_printer) + nt.assert_is(f.for_type_by_name(mod, 'C', None), foo_printer) + +def test_lookup(): + f = PlainTextFormatter() + + f.for_type(C, foo_printer) + nt.assert_is(f.lookup(C()), foo_printer) + with nt.assert_raises(KeyError): + f.lookup(A()) + +def test_lookup_string(): + f = PlainTextFormatter() + type_str = '%s.%s' % (C.__module__, 'C') + + f.for_type(type_str, foo_printer) + nt.assert_is(f.lookup(C()), foo_printer) + # should move from deferred to imported dict + nt.assert_not_in(_mod_name_key(C), f.deferred_printers) + nt.assert_in(C, f.type_printers) + +def test_lookup_by_type(): + f = PlainTextFormatter() + f.for_type(C, foo_printer) + nt.assert_is(f.lookup_by_type(C), foo_printer) + type_str = '%s.%s' % (C.__module__, 'C') + with nt.assert_raises(KeyError): + f.lookup_by_type(A) + +def test_lookup_by_type_string(): + f = PlainTextFormatter() + type_str = '%s.%s' % (C.__module__, 'C') + f.for_type(type_str, foo_printer) + + # verify insertion + nt.assert_in(_mod_name_key(C), f.deferred_printers) + nt.assert_not_in(C, f.type_printers) + + nt.assert_is(f.lookup_by_type(type_str), foo_printer) + # lookup by string doesn't cause import + nt.assert_in(_mod_name_key(C), f.deferred_printers) + nt.assert_not_in(C, f.type_printers) + + nt.assert_is(f.lookup_by_type(C), foo_printer) + # should move from deferred to imported dict + nt.assert_not_in(_mod_name_key(C), f.deferred_printers) + nt.assert_in(C, f.type_printers) + +def test_in_formatter(): + f = PlainTextFormatter() + f.for_type(C, foo_printer) + type_str = '%s.%s' % (C.__module__, 'C') + nt.assert_in(C, f) + nt.assert_in(type_str, f) + +def test_string_in_formatter(): + f = PlainTextFormatter() + type_str = '%s.%s' % (C.__module__, 'C') + f.for_type(type_str, foo_printer) + nt.assert_in(type_str, f) + nt.assert_in(C, f) + +def test_pop(): + f = PlainTextFormatter() + f.for_type(C, foo_printer) + nt.assert_is(f.lookup_by_type(C), foo_printer) + nt.assert_is(f.pop(C, None), foo_printer) + f.for_type(C, foo_printer) + nt.assert_is(f.pop(C), foo_printer) + with nt.assert_raises(KeyError): + f.lookup_by_type(C) + with nt.assert_raises(KeyError): + f.pop(C) + with nt.assert_raises(KeyError): + f.pop(A) + nt.assert_is(f.pop(A, None), None) + +def test_pop_string(): + f = PlainTextFormatter() + type_str = '%s.%s' % (C.__module__, 'C') + + with nt.assert_raises(KeyError): + f.pop(type_str) + + f.for_type(type_str, foo_printer) + f.pop(type_str) + with nt.assert_raises(KeyError): + f.lookup_by_type(C) + with nt.assert_raises(KeyError): + f.pop(type_str) + + f.for_type(C, foo_printer) + nt.assert_is(f.pop(type_str, None), foo_printer) + with nt.assert_raises(KeyError): + f.lookup_by_type(C) + with nt.assert_raises(KeyError): + f.pop(type_str) + nt.assert_is(f.pop(type_str, None), None) + + diff --git a/docs/source/whatsnew/pr/formatters.txt b/docs/source/whatsnew/pr/formatters.txt new file mode 100644 index 0000000..6ea9efd --- /dev/null +++ b/docs/source/whatsnew/pr/formatters.txt @@ -0,0 +1,20 @@ +DisplayFormatter changes +======================== + +There was no official way to query or remove callbacks in the Formatter API. +To remedy this, the following methods are added to :class:`BaseFormatter`: + +- ``lookup(instance)`` - return appropriate callback or a given object +- ``lookup_by_type(type_or_str)`` - return appropriate callback for a given type or ``'mod.name'`` type string +- ``pop(type_or_str)`` - remove a type (by type or string). + Pass a second argument to avoid KeyError (like dict). + +All of the above methods raise a KeyError if no match is found. + +And the following methods are changed: + +- ``for_type(type_or_str)`` - behaves the same as before, only adding support for ``'mod.name'`` + type strings in addition to plain types. This removes the need for ``for_type_by_name()``, + but it remains for backward compatibility. + +