diff --git a/IPython/core/display.py b/IPython/core/display.py
index 75ca479..5411287 100644
--- a/IPython/core/display.py
+++ b/IPython/core/display.py
@@ -423,8 +423,11 @@ class Javascript(DisplayObject):
class Image(DisplayObject):
_read_flags = 'rb'
+ _FMT_JPEG = u'jpeg'
+ _FMT_PNG = u'png'
+ _ACCEPTABLE_EMBEDDINGS = [_FMT_JPEG, _FMT_PNG]
- def __init__(self, data=None, url=None, filename=None, format=u'png', embed=None):
+ def __init__(self, data=None, url=None, filename=None, format=u'png', embed=None, width=None, height=None):
"""Create a display an PNG/JPEG image given raw data.
When this object is returned by an expression or passed to the
@@ -451,6 +454,10 @@ class Image(DisplayObject):
default value is `False`.
Note that QtConsole is not able to display images if `embed` is set to `False`
+ width : int
+ Width to which to constrain the image in html
+ height : int
+ Height to which to constrain the image in html
Examples
--------
@@ -466,17 +473,27 @@ class Image(DisplayObject):
ext = self._find_ext(filename)
elif url is not None:
ext = self._find_ext(url)
+ elif data is None:
+ raise ValueError("No image data found. Expecting filename, url, or data.")
elif data.startswith('http'):
ext = self._find_ext(data)
else:
ext = None
+
if ext is not None:
+ format = ext.lower()
if ext == u'jpg' or ext == u'jpeg':
- format = u'jpeg'
+ format = self._FMT_JPEG
if ext == u'png':
- format = u'png'
+ format = self._FMT_PNG
+
self.format = unicode(format).lower()
self.embed = embed if embed is not None else (url is None)
+
+ if self.embed and self.format not in self._ACCEPTABLE_EMBEDDINGS:
+ raise ValueError("Cannot embed the '%s' image format" % (self.format))
+ self.width = width
+ self.height = height
super(Image, self).__init__(data=data, url=url, filename=filename)
def reload(self):
@@ -486,7 +503,12 @@ class Image(DisplayObject):
def _repr_html_(self):
if not self.embed:
- return u'' % self.url
+ width = height = ''
+ if self.width:
+ width = ' width="%d"' % self.width
+ if self.height:
+ height = ' height="%d"' % self.height
+ return u'' % (self.url, width, height)
def _repr_png_(self):
if self.embed and self.format == u'png':
diff --git a/IPython/core/tests/test_display.py b/IPython/core/tests/test_display.py
new file mode 100644
index 0000000..61f5f8f
--- /dev/null
+++ b/IPython/core/tests/test_display.py
@@ -0,0 +1,38 @@
+#-----------------------------------------------------------------------------
+# Copyright (C) 2010-2011 The IPython Development Team.
+#
+# Distributed under the terms of the BSD License.
+#
+# The full license is in the file COPYING.txt, distributed with this software.
+#-----------------------------------------------------------------------------
+import os
+
+import nose.tools as nt
+
+from IPython.core import display
+from IPython.utils import path as ipath
+
+def test_image_size():
+ """Simple test for display.Image(args, width=x,height=y)"""
+ thisurl = 'http://www.google.fr/images/srpr/logo3w.png'
+ img = display.Image(url=thisurl, width=200, height=200)
+ nt.assert_equal(u'' % (thisurl), img._repr_html_())
+ img = display.Image(url=thisurl, width=200)
+ nt.assert_equal(u'' % (thisurl), img._repr_html_())
+ img = display.Image(url=thisurl)
+ nt.assert_equal(u'' % (thisurl), img._repr_html_())
+
+def test_image_filename_defaults():
+ '''test format constraint, and validity of jpeg and png'''
+ tpath = ipath.get_ipython_package_dir()
+ nt.assert_raises(ValueError, display.Image, filename=os.path.join(tpath, 'testing/tests/badformat.gif'),
+ embed=True)
+ nt.assert_raises(ValueError, display.Image)
+ nt.assert_raises(ValueError, display.Image, data='this is not an image', format='badformat', embed=True)
+ imgfile = os.path.join(tpath, 'frontend/html/notebook/static/ipynblogo.png')
+ img = display.Image(filename=imgfile)
+ nt.assert_equal('png', img.format)
+ nt.assert_is_not_none(img._repr_png_())
+ img = display.Image(filename=os.path.join(tpath, 'testing/tests/logo.jpg'), embed=False)
+ nt.assert_equal('jpeg', img.format)
+ nt.assert_is_none(img._repr_jpeg_())
diff --git a/IPython/testing/nose_assert_methods.py b/IPython/testing/nose_assert_methods.py
index c50a7fe..f41ad48 100644
--- a/IPython/testing/nose_assert_methods.py
+++ b/IPython/testing/nose_assert_methods.py
@@ -15,3 +15,15 @@ def assert_not_in(item, collection):
if not hasattr(nt, 'assert_not_in'):
nt.assert_not_in = assert_not_in
+
+def assert_is_none(obj):
+ assert obj is None, '%r is not None' % obj
+
+if not hasattr(nt, 'assert_is_none'):
+ nt.assert_is_none = assert_is_none
+
+def assert_is_not_none(obj):
+ assert obj is not None, '%r is None' % obj
+
+if not hasattr(nt, 'assert_is_not_none'):
+ nt.assert_is_not_none = assert_is_not_none