diff --git a/IPython/nbformat/v2/nbbase.py b/IPython/nbformat/v2/nbbase.py
index 00f893e..0b5b517 100644
--- a/IPython/nbformat/v2/nbbase.py
+++ b/IPython/nbformat/v2/nbbase.py
@@ -119,12 +119,14 @@ def new_text_cell(cell_type, source=None, rendered=None):
     return cell
 
 
-def new_heading_cell(source=None, level=1):
+def new_heading_cell(source=None, rendered=None, level=1):
     """Create a new section cell with a given integer level."""
     cell = NotebookNode()
     cell.cell_type = u'heading'
     if source is not None:
         cell.source = unicode(source)
+    if rendered is not None:
+        cell.rendered = unicode(rendered)
     cell.level = int(level)
     return cell
 
diff --git a/IPython/nbformat/v2/rwbase.py b/IPython/nbformat/v2/rwbase.py
index ae4e53d..7566b33 100644
--- a/IPython/nbformat/v2/rwbase.py
+++ b/IPython/nbformat/v2/rwbase.py
@@ -64,9 +64,7 @@ def rejoin_lines(nb):
                         item = output.get(key, None)
                         if isinstance(item, list):
                             output[key] = u'\n'.join(item)
-            elif cell.cell_type == 'heading':
-                pass
-            else: # text cell
+            else: # text, heading cell
                 for key in ['source', 'rendered']:
                     item = cell.get(key, None)
                     if isinstance(item, list):
@@ -92,9 +90,7 @@ def split_lines(nb):
                         item = output.get(key, None)
                         if isinstance(item, basestring):
                             output[key] = item.splitlines()
-            elif cell.cell_type == 'heading':
-                pass
-            else: # text cell
+            else: # text, heading cell
                 for key in ['source', 'rendered']:
                     item = cell.get(key, None)
                     if isinstance(item, basestring):
diff --git a/IPython/nbformat/v2/tests/test_nbbase.py b/IPython/nbformat/v2/tests/test_nbbase.py
index 5a90d35..d973c79 100644
--- a/IPython/nbformat/v2/tests/test_nbbase.py
+++ b/IPython/nbformat/v2/tests/test_nbbase.py
@@ -74,10 +74,12 @@ class TestCell(TestCase):
         tc = new_heading_cell()
         self.assertEquals(tc.cell_type, u'heading')
         self.assertEquals(u'source' not in tc, True)
+        self.assertEquals(u'rendered' not in tc, True)
 
     def test_heading_cell(self):
-        tc = new_heading_cell(u'My Heading', level=2)
+        tc = new_heading_cell(u'My Heading', u'My Heading (rendered)', level=2)
         self.assertEquals(tc.source, u'My Heading')
+        self.assertEquals(tc.rendered, u'My Heading (rendered)')
         self.assertEquals(tc.level, 2)
 
 