##// END OF EJS Templates
Merge pull request #11650 from matangover/optional-normalization...
Matthias Bussonnier -
r24975:08aaa2e1 merge
parent child Browse files
Show More
@@ -0,0 +1,7
1 Make audio normalization optional
2 =================================
3
4 Added a `normalize` argument to `IPython.display.Audio`. This argument applies
5 when audio data is given as an array of samples. The default of `normalize=True`
6 preserves prior behavior of normalizing the audio to the maximum possible range.
7 Setting it to `False` disables normalization. No newline at end of file
@@ -54,6 +54,12 class Audio(DisplayObject):
54 autoplay : bool
54 autoplay : bool
55 Set to True if the audio should immediately start playing.
55 Set to True if the audio should immediately start playing.
56 Default is `False`.
56 Default is `False`.
57 normalize : bool
58 Whether audio should be normalized (rescaled) to the maximum possible
59 range. Default is `True`. When set to `False`, `data` must be between
60 -1 and 1 (inclusive), otherwise an error is raised.
61 Applies only when `data` is a list or array of samples; other types of
62 audio are never normalized.
57
63
58 Examples
64 Examples
59 --------
65 --------
@@ -83,9 +89,9 class Audio(DisplayObject):
83 """
89 """
84 _read_flags = 'rb'
90 _read_flags = 'rb'
85
91
86 def __init__(self, data=None, filename=None, url=None, embed=None, rate=None, autoplay=False):
92 def __init__(self, data=None, filename=None, url=None, embed=None, rate=None, autoplay=False, normalize=True):
87 if filename is None and url is None and data is None:
93 if filename is None and url is None and data is None:
88 raise ValueError("No image data found. Expecting filename, url, or data.")
94 raise ValueError("No audio data found. Expecting filename, url, or data.")
89 if embed is False and url is None:
95 if embed is False and url is None:
90 raise ValueError("No url found. Expecting url when embed=False")
96 raise ValueError("No url found. Expecting url when embed=False")
91
97
@@ -97,7 +103,9 class Audio(DisplayObject):
97 super(Audio, self).__init__(data=data, url=url, filename=filename)
103 super(Audio, self).__init__(data=data, url=url, filename=filename)
98
104
99 if self.data is not None and not isinstance(self.data, bytes):
105 if self.data is not None and not isinstance(self.data, bytes):
100 self.data = self._make_wav(data,rate)
106 if rate is None:
107 raise ValueError("rate must be specified when data is a numpy array or list of audio samples.")
108 self.data = Audio._make_wav(data, rate, normalize)
101
109
102 def reload(self):
110 def reload(self):
103 """Reload the raw data from file or URL."""
111 """Reload the raw data from file or URL."""
@@ -112,41 +120,17 class Audio(DisplayObject):
112 else:
120 else:
113 self.mimetype = "audio/wav"
121 self.mimetype = "audio/wav"
114
122
115 def _make_wav(self, data, rate):
123 @staticmethod
124 def _make_wav(data, rate, normalize):
116 """ Transform a numpy array to a PCM bytestring """
125 """ Transform a numpy array to a PCM bytestring """
117 import struct
126 import struct
118 from io import BytesIO
127 from io import BytesIO
119 import wave
128 import wave
120
129
121 try:
130 try:
122 import numpy as np
131 scaled, nchan = Audio._validate_and_normalize_with_numpy(data, normalize)
123
124 data = np.array(data, dtype=float)
125 if len(data.shape) == 1:
126 nchan = 1
127 elif len(data.shape) == 2:
128 # In wave files,channels are interleaved. E.g.,
129 # "L1R1L2R2..." for stereo. See
130 # http://msdn.microsoft.com/en-us/library/windows/hardware/dn653308(v=vs.85).aspx
131 # for channel ordering
132 nchan = data.shape[0]
133 data = data.T.ravel()
134 else:
135 raise ValueError('Array audio input must be a 1D or 2D array')
136 scaled = np.int16(data/np.max(np.abs(data))*32767).tolist()
137 except ImportError:
132 except ImportError:
138 # check that it is a "1D" list
133 scaled, nchan = Audio._validate_and_normalize_without_numpy(data, normalize)
139 idata = iter(data) # fails if not an iterable
140 try:
141 iter(idata.next())
142 raise TypeError('Only lists of mono audio are '
143 'supported if numpy is not installed')
144 except TypeError:
145 # this means it's not a nested list, which is what we want
146 pass
147 maxabsvalue = float(max([abs(x) for x in data]))
148 scaled = [int(x/maxabsvalue*32767) for x in data]
149 nchan = 1
150
134
151 fp = BytesIO()
135 fp = BytesIO()
152 waveobj = wave.open(fp,mode='wb')
136 waveobj = wave.open(fp,mode='wb')
@@ -160,6 +144,48 class Audio(DisplayObject):
160
144
161 return val
145 return val
162
146
147 @staticmethod
148 def _validate_and_normalize_with_numpy(data, normalize):
149 import numpy as np
150
151 data = np.array(data, dtype=float)
152 if len(data.shape) == 1:
153 nchan = 1
154 elif len(data.shape) == 2:
155 # In wave files,channels are interleaved. E.g.,
156 # "L1R1L2R2..." for stereo. See
157 # http://msdn.microsoft.com/en-us/library/windows/hardware/dn653308(v=vs.85).aspx
158 # for channel ordering
159 nchan = data.shape[0]
160 data = data.T.ravel()
161 else:
162 raise ValueError('Array audio input must be a 1D or 2D array')
163
164 max_abs_value = np.max(np.abs(data))
165 normalization_factor = Audio._get_normalization_factor(max_abs_value, normalize)
166 scaled = np.int16(data / normalization_factor * 32767).tolist()
167 return scaled, nchan
168
169
170 @staticmethod
171 def _validate_and_normalize_without_numpy(data, normalize):
172 try:
173 max_abs_value = float(max([abs(x) for x in data]))
174 except TypeError:
175 raise TypeError('Only lists of mono audio are '
176 'supported if numpy is not installed')
177
178 normalization_factor = Audio._get_normalization_factor(max_abs_value, normalize)
179 scaled = [int(x / normalization_factor * 32767) for x in data]
180 nchan = 1
181 return scaled, nchan
182
183 @staticmethod
184 def _get_normalization_factor(max_abs_value, normalize):
185 if not normalize and max_abs_value > 1:
186 raise ValueError('Audio data must be between -1 and 1 when normalize=False.')
187 return max_abs_value if normalize else 1
188
163 def _data_and_metadata(self):
189 def _data_and_metadata(self):
164 """shortcut for returning metadata with url information, if defined"""
190 """shortcut for returning metadata with url information, if defined"""
165 md = {}
191 md = {}
@@ -19,13 +19,17 try:
19 import pathlib
19 import pathlib
20 except ImportError:
20 except ImportError:
21 pass
21 pass
22 from unittest import TestCase, mock
23 import struct
24 import wave
25 from io import BytesIO
22
26
23 # Third-party imports
27 # Third-party imports
24 import nose.tools as nt
28 import nose.tools as nt
29 import numpy
25
30
26 # Our own imports
31 # Our own imports
27 from IPython.lib import display
32 from IPython.lib import display
28 from IPython.testing.decorators import skipif_not_numpy
29
33
30 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
31 # Classes and functions
35 # Classes and functions
@@ -179,11 +183,71 def test_recursive_FileLinks():
179 actual = actual.split('\n')
183 actual = actual.split('\n')
180 nt.assert_equal(len(actual), 2, actual)
184 nt.assert_equal(len(actual), 2, actual)
181
185
182 @skipif_not_numpy
183 def test_audio_from_file():
186 def test_audio_from_file():
184 path = pjoin(dirname(__file__), 'test.wav')
187 path = pjoin(dirname(__file__), 'test.wav')
185 display.Audio(filename=path)
188 display.Audio(filename=path)
186
189
190 class TestAudioDataWithNumpy(TestCase):
191 def test_audio_from_numpy_array(self):
192 test_tone = get_test_tone()
193 audio = display.Audio(test_tone, rate=44100)
194 nt.assert_equal(len(read_wav(audio.data)), len(test_tone))
195
196 def test_audio_from_list(self):
197 test_tone = get_test_tone()
198 audio = display.Audio(list(test_tone), rate=44100)
199 nt.assert_equal(len(read_wav(audio.data)), len(test_tone))
200
201 def test_audio_from_numpy_array_without_rate_raises(self):
202 nt.assert_raises(ValueError, display.Audio, get_test_tone())
203
204 def test_audio_data_normalization(self):
205 expected_max_value = numpy.iinfo(numpy.int16).max
206 for scale in [1, 0.5, 2]:
207 audio = display.Audio(get_test_tone(scale), rate=44100)
208 actual_max_value = numpy.max(numpy.abs(read_wav(audio.data)))
209 nt.assert_equal(actual_max_value, expected_max_value)
210
211 def test_audio_data_without_normalization(self):
212 max_int16 = numpy.iinfo(numpy.int16).max
213 for scale in [1, 0.5, 0.2]:
214 test_tone = get_test_tone(scale)
215 test_tone_max_abs = numpy.max(numpy.abs(test_tone))
216 expected_max_value = int(max_int16 * test_tone_max_abs)
217 audio = display.Audio(test_tone, rate=44100, normalize=False)
218 actual_max_value = numpy.max(numpy.abs(read_wav(audio.data)))
219 nt.assert_equal(actual_max_value, expected_max_value)
220
221 def test_audio_data_without_normalization_raises_for_invalid_data(self):
222 nt.assert_raises(
223 ValueError,
224 lambda: display.Audio([1.001], rate=44100, normalize=False))
225 nt.assert_raises(
226 ValueError,
227 lambda: display.Audio([-1.001], rate=44100, normalize=False))
228
229 def simulate_numpy_not_installed():
230 return mock.patch('numpy.array', mock.MagicMock(side_effect=ImportError))
231
232 @simulate_numpy_not_installed()
233 class TestAudioDataWithoutNumpy(TestAudioDataWithNumpy):
234 # All tests from `TestAudioDataWithNumpy` are inherited.
235
236 def test_audio_raises_for_nested_list(self):
237 stereo_signal = [list(get_test_tone())] * 2
238 nt.assert_raises(
239 TypeError,
240 lambda: display.Audio(stereo_signal, rate=44100))
241
242 def get_test_tone(scale=1):
243 return numpy.sin(2 * numpy.pi * 440 * numpy.linspace(0, 1, 44100)) * scale
244
245 def read_wav(data):
246 with wave.open(BytesIO(data)) as wave_file:
247 wave_data = wave_file.readframes(wave_file.getnframes())
248 num_samples = wave_file.getnframes() * wave_file.getnchannels()
249 return struct.unpack('<%sh' % num_samples, wave_data)
250
187 def test_code_from_file():
251 def test_code_from_file():
188 c = display.Code(filename=__file__)
252 c = display.Code(filename=__file__)
189 assert c._repr_html_().startswith('<style>')
253 assert c._repr_html_().startswith('<style>')
General Comments 0
You need to be logged in to leave comments. Login now