From f3b5f76f0ac8465c52fa0dec97ac26e71d09a41b 2013-12-06 17:59:20 From: MinRK Date: 2013-12-06 17:59:20 Subject: [PATCH] another pass on formatter lookup - pop takes a default value (like dict) - support `type in Formatter` - allow lookup by type string (for consistency) --- diff --git a/IPython/core/formatters.py b/IPython/core/formatters.py index bde584c..06e7391 100644 --- a/IPython/core/formatters.py +++ b/IPython/core/formatters.py @@ -208,7 +208,10 @@ class FormatterABC(with_metaclass(abc.ABCMeta, object)): def _mod_name_key(typ): - """Return a '__module__.__name__' string key for a type.""" + """Return a (__module__, __name__) tuple for a type. + + Used as key in Formatter.deferred_printers. + """ module = getattr(typ, '__module__', None) name = getattr(typ, '__name__', None) return (module, name) @@ -218,6 +221,7 @@ def _get_type(obj): """Return the type of an instance (old and new-style)""" return getattr(obj, '__class__', None) or type(obj) +_raise_key_error = object() class BaseFormatter(Configurable): """A base formatter class that is configurable. @@ -283,6 +287,15 @@ class BaseFormatter(Configurable): else: return None + def __contains__(self, typ): + """map in to lookup_by_type""" + try: + self.lookup_by_type(typ) + except KeyError: + return False + else: + return True + def lookup(self, obj): """Look up the formatter for a given instance. @@ -293,7 +306,7 @@ class BaseFormatter(Configurable): Returns ------- f : callable - The registered fromatting callable for the type. + The registered formatting callable for the type. Raises ------ @@ -307,7 +320,7 @@ class BaseFormatter(Configurable): return self.lookup_by_type(_get_type(obj)) def lookup_by_type(self, typ): - """ Look up all the registered formatters for a type. + """Look up the registered formatter for a type. Parameters ---------- @@ -316,15 +329,26 @@ class BaseFormatter(Configurable): Returns ------- f : callable - The registered fromatting callable for the type. + The registered formatting callable for the type. Raises ------ KeyError if the type has not been registered. """ - for cls in pretty._get_mro(typ): - if cls in self.type_printers or self._in_deferred_types(cls): - return self.type_printers[cls] + if isinstance(typ, string_types): + typ_key = tuple(typ.rsplit('.',1)) + if typ_key not in self.deferred_printers: + # We may have it cached in the type map. We will have to + # iterate over all of the types to check. + for cls in self.type_printers: + if _mod_name_key(cls) == typ_key: + return self.type_printers[cls] + else: + return self.deferred_printers[typ_key] + else: + for cls in pretty._get_mro(typ): + if cls in self.type_printers or self._in_deferred_types(cls): + return self.type_printers[cls] # If we have reached here, the lookup failed. raise KeyError("No registered printer for {0!r}".format(typ)) @@ -398,17 +422,23 @@ class BaseFormatter(Configurable): """ key = (type_module, type_name) - oldfunc = self.deferred_printers.get(key, None) + try: + oldfunc = self.lookup_by_type("%s.%s" % key) + except KeyError: + oldfunc = None + if func is not None: self.deferred_printers[key] = func return oldfunc - def pop(self, typ): - """ Pop a registered object for the given type. + def pop(self, typ, default=_raise_key_error): + """Pop a formatter for the given type. Parameters ---------- typ : type or '__module__.__name__' string for a type + default : object + value to be returned if no formatter is registered for typ. Returns ------- @@ -417,8 +447,9 @@ class BaseFormatter(Configurable): Raises ------ - KeyError if the type is not registered. + KeyError if the type is not registered and default is not specified. """ + if isinstance(typ, string_types): typ_key = tuple(typ.rsplit('.',1)) if typ_key not in self.deferred_printers: @@ -429,14 +460,16 @@ class BaseFormatter(Configurable): old = self.type_printers.pop(cls) break else: - raise KeyError("No registered value for {0!r}".format(typ_key)) + old = default else: old = self.deferred_printers.pop(typ_key) else: if typ in self.type_printers: old = self.type_printers.pop(typ) else: - old = self.deferred_printers.pop(_mod_name_key(typ)) + old = self.deferred_printers.pop(_mod_name_key(typ), default) + if old is _raise_key_error: + raise KeyError("No registered value for {0!r}".format(typ)) return old def _in_deferred_types(self, cls): diff --git a/IPython/core/pylabtools.py b/IPython/core/pylabtools.py index 40cbabb..e581cc2 100644 --- a/IPython/core/pylabtools.py +++ b/IPython/core/pylabtools.py @@ -173,22 +173,13 @@ def select_figure_format(shell, fmt): png_formatter = shell.display_formatter.formatters['image/png'] if fmt == 'png': - try: - svg_formatter.pop(Figure) - except KeyError: - pass + svg_formatter.pop(Figure, None) png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png')) elif fmt in ('png2x', 'retina'): - try: - svg_formatter.pop(Figure) - except KeyError: - pass + svg_formatter.pop(Figure, None) png_formatter.for_type(Figure, retina_figure) elif fmt == 'svg': - try: - svg_formatter.pop(Figure) - except KeyError: - pass + png_formatter.pop(Figure, None) svg_formatter.for_type(Figure, lambda fig: print_figure(fig, 'svg')) else: raise ValueError("supported formats are: 'png', 'retina', 'svg', not %r" % fmt) diff --git a/IPython/core/tests/test_formatters.py b/IPython/core/tests/test_formatters.py index 843bef2..c792cae 100644 --- a/IPython/core/tests/test_formatters.py +++ b/IPython/core/tests/test_formatters.py @@ -169,22 +169,44 @@ def test_lookup_by_type_string(): nt.assert_in(_mod_name_key(C), f.deferred_printers) nt.assert_not_in(C, f.type_printers) + nt.assert_is(f.lookup_by_type(type_str), foo_printer) + # lookup by string doesn't cause import + nt.assert_in(_mod_name_key(C), f.deferred_printers) + nt.assert_not_in(C, f.type_printers) + nt.assert_is(f.lookup_by_type(C), foo_printer) # should move from deferred to imported dict nt.assert_not_in(_mod_name_key(C), f.deferred_printers) nt.assert_in(C, f.type_printers) +def test_in_formatter(): + f = PlainTextFormatter() + f.for_type(C, foo_printer) + type_str = '%s.%s' % (C.__module__, 'C') + nt.assert_in(C, f) + nt.assert_in(type_str, f) + +def test_string_in_formatter(): + f = PlainTextFormatter() + type_str = '%s.%s' % (C.__module__, 'C') + f.for_type(type_str, foo_printer) + nt.assert_in(type_str, f) + nt.assert_in(C, f) + def test_pop(): f = PlainTextFormatter() f.for_type(C, foo_printer) nt.assert_is(f.lookup_by_type(C), foo_printer) - f.pop(C) + nt.assert_is(f.pop(C, None), foo_printer) + f.for_type(C, foo_printer) + nt.assert_is(f.pop(C), foo_printer) with nt.assert_raises(KeyError): f.lookup_by_type(C) with nt.assert_raises(KeyError): f.pop(C) with nt.assert_raises(KeyError): f.pop(A) + nt.assert_is(f.pop(A, None), None) def test_pop_string(): f = PlainTextFormatter() @@ -201,10 +223,12 @@ def test_pop_string(): f.pop(type_str) f.for_type(C, foo_printer) - f.pop(type_str) + nt.assert_is(f.pop(type_str, None), foo_printer) with nt.assert_raises(KeyError): f.lookup_by_type(C) with nt.assert_raises(KeyError): f.pop(type_str) + nt.assert_is(f.pop(type_str, None), None) +