xref: /aosp_15_r20/external/autotest/server/brillo/audio_utils.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1# Lint as: python2, python3
2# Copyright (c) 2016 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""Server side audio utilities functions for Brillo."""
7
8from __future__ import absolute_import
9from __future__ import division
10from __future__ import print_function
11
12import contextlib
13import logging
14import numpy
15import os
16import struct
17import subprocess
18import tempfile
19import wave
20
21from autotest_lib.client.common_lib import error
22from six.moves import map
23from six.moves import range
24
25
# Number of bits in one byte. Used to convert a sample width given in bytes
# into the bit depth passed to sox.
_BITS_PER_BYTE=8

# Thresholds used when comparing files.
#
# The frequency threshold used when comparing files, expressed as a fraction
# (0.01 means 1%). The dominant frequency of the recorded audio has to be
# within this fraction of the dominant frequency of the original audio.
_FREQUENCY_THRESHOLD = 0.01
# Noise threshold controls how much noise is allowed as a fraction of the
# magnitude of the peak frequency after taking an FFT (0.05 means 5%). The
# magnitude of all the other frequencies in the signal should stay below
# _FFT_NOISE_THRESHOLD times the magnitude of the main frequency.
_FFT_NOISE_THRESHOLD = 0.05

# Command used to encode audio. If you want to test with something different,
# this should be changed.
_ENCODING_CMD = 'sox'
43
44
def extract_wav_frames(wave_file):
    """Extract all samples from a WAV file.

    @param wave_file: A Wave_read object representing a WAV file opened for
                      reading.

    @return: A flat list of signed integer samples in the order they are
             stored in the file. For a file with n channels the samples are
             interleaved, so channel c is frames[c::n]. 8-bit samples are
             converted from unsigned to signed by removing the 128 offset.

    @raise ValueError: The sample width is not 1, 2 or 4 bytes.
    """
    # struct type code for each supported sample width (in bytes).
    sample_codes = {1: 'B', 2: 'h', 4: 'i'}
    sample_width = wave_file.getsampwidth()
    if sample_width not in sample_codes:
        raise ValueError('Unsupported sample width')

    num_frames = wave_file.getnframes()
    # Every frame holds one sample per channel.
    num_samples = num_frames * wave_file.getnchannels()
    # WAV sample data is little-endian. The '<' prefix pins both the byte
    # order and the standard type sizes instead of relying on the native
    # layout of the machine running the test.
    fmt = '<%d%s' % (num_samples, sample_codes[sample_width])
    frames = list(struct.unpack(fmt, wave_file.readframes(num_frames)))

    # Since 8-bit PCM is unsigned with an offset of 128, we subtract the offset
    # to make it signed since the rest of the code assumes signed numbers.
    if sample_width == 1:
        frames = [val - 128 for val in frames]

    return frames
72
73
def check_wav_file(filename, num_channels=None, sample_rate=None,
                   sample_width=None):
    """Checks a WAV file and returns its peak PCM values.

    @param filename: Input WAV file to analyze.
    @param num_channels: Number of channels to expect (None to not check).
    @param sample_rate: Sample rate to expect (None to not check).
    @param sample_width: Sample width to expect (None to not check).

    @return A list of the absolute maximum PCM values for each channel in the
            WAV file.

    @raise ValueError: Failed to process the WAV file or validate an attribute.
    """
    chk_file = None
    try:
        chk_file = wave.open(filename, 'r')
        # Read the attributes once, while the file is still open, so nothing
        # below has to query the object after it has been closed.
        actual_channels = chk_file.getnchannels()
        actual_rate = chk_file.getframerate()
        actual_width = chk_file.getsampwidth()
        if num_channels is not None and actual_channels != num_channels:
            raise ValueError('Expected %d channels but got %d instead.' %
                             (num_channels, actual_channels))
        if sample_rate is not None and actual_rate != sample_rate:
            raise ValueError('Expected sample rate %d but got %d instead.' %
                             (sample_rate, actual_rate))
        if sample_width is not None and actual_width != sample_width:
            raise ValueError('Expected sample width %d but got %d instead.' %
                             (sample_width, actual_width))
        frames = extract_wav_frames(chk_file)
    except (wave.Error, EOFError) as e:
        # The wave module reports a truncated file with EOFError; fold it
        # into the ValueError contract documented above.
        raise ValueError('Error processing WAV file: %s' % e)
    finally:
        if chk_file is not None:
            chk_file.close()

    peaks = []
    for channel in range(actual_channels):
        # Samples are interleaved, so every actual_channels-th value belongs
        # to the same channel.
        channel_frames = frames[channel::actual_channels]
        if not channel_frames:
            raise ValueError('WAV file contains no audio frames.')
        peaks.append(max(abs(val) for val in channel_frames))
    return peaks
111
112
def generate_sine_file(host, num_channels, sample_rate, sample_width,
                       duration_secs, sine_frequency, temp_dir,
                       file_format='wav'):
    """Generate a sine file and push it to the DUT.

    A WAV sine wave is synthesized with sox. If file_format is not 'wav', the
    WAV file is additionally converted with _ENCODING_CMD and the converted
    file is the one sent to the DUT.

    @param host: An object representing the DUT.
    @param num_channels: Number of channels to use.
    @param sample_rate: Sample rate to use for sine wave generation.
    @param sample_width: Sample width (in bytes) to use for sine wave
                         generation.
    @param duration_secs: Duration in seconds to generate sine wave for.
    @param sine_frequency: Frequency to generate sine wave at.
    @param temp_dir: A temporary directory on the host.
    @param file_format: A string representing the encoding for the audio file.

    @return A tuple of the filename of the generated WAV file on the server
            and the filename of the (possibly re-encoded) file on the DUT.

    @raise error.TestError: The generation or conversion command exited with
                            a non-zero status.
    """
    fd, local_filename = tempfile.mkstemp(
        prefix='sine-', suffix='.' + file_format, dir=temp_dir)
    # mkstemp hands back an open descriptor. Only the name is needed because
    # the external command reopens the file, so close it to avoid a leak.
    os.close(fd)

    # 8-bit PCM is unsigned; wider samples are signed.
    encoding = 'unsigned' if sample_width == 1 else 'signed'
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact.
    gen_file_cmd = [
        'sox', '-n', '-t', 'wav',
        '-c', '%d' % num_channels,
        '-e', encoding,
        '-b', '%d' % (sample_width * _BITS_PER_BYTE),
        '-r', '%d' % sample_rate,
        local_filename,
        'synth', '%d' % duration_secs,
        'sine', '%d' % sine_frequency,
        'vol', '0.9',
    ]
    logging.info('Command to generate sine wave: %s', ' '.join(gen_file_cmd))
    exit_status = subprocess.call(gen_file_cmd)
    if exit_status != 0:
        raise error.TestError(
                'Failed to generate sine wave, exit status %d: %s' %
                (exit_status, ' '.join(gen_file_cmd)))

    if file_format != 'wav':
        # Convert the file to the appropriate format.
        logging.info('Converting file to %s', file_format)
        fd, local_encoded_filename = tempfile.mkstemp(
                prefix='sine-', suffix='.' + file_format, dir=temp_dir)
        os.close(fd)
        # _ENCODING_CMD may carry its own arguments, so split it on
        # whitespace before appending the input and output paths.
        cvt_file_cmd = _ENCODING_CMD.split() + [local_filename,
                                                local_encoded_filename]
        logging.info('Command to convert file: %s', ' '.join(cvt_file_cmd))
        exit_status = subprocess.call(cvt_file_cmd)
        if exit_status != 0:
            raise error.TestError(
                    'Failed to convert audio file, exit status %d: %s' %
                    (exit_status, ' '.join(cvt_file_cmd)))
    else:
        local_encoded_filename = local_filename

    dut_tmp_dir = '/data'
    remote_filename = os.path.join(dut_tmp_dir, 'sine.' + file_format)
    logging.info('Send file to DUT.')
    # TODO(ralphnathan): Find a better place to put this file once the SELinux
    # issues are resolved.
    logging.info('remote_filename %s', remote_filename)
    host.send_file(local_encoded_filename, remote_filename)
    return local_filename, remote_filename
160
161
def _is_outside_frequency_threshold(freq_reference, freq_rec):
    """Checks whether a frequency deviates too far from a reference frequency.

    This function is used to decide whether the frequencies corresponding to
    the peak FFT values are similar, meaning that the dominant frequency in
    the audio signal is the same for the recorded audio as that in the audio
    played.

    @param freq_reference: The dominant frequency in the reference audio file.
    @param freq_rec: The dominant frequency in the recorded audio file.

    @return: True if freq_rec differs from freq_reference by more than the
             fraction _FREQUENCY_THRESHOLD of freq_reference, False if it is
             within that tolerance.
    """
    if freq_reference == 0:
        # A relative tolerance around zero has no width, and dividing by the
        # reference would fail: only an exact match is inside the threshold.
        return bool(freq_rec != 0)
    ratio = float(freq_rec) / freq_reference
    return bool(ratio > 1 + _FREQUENCY_THRESHOLD or
                ratio < 1 - _FREQUENCY_THRESHOLD)
179
180
def _compare_frames(reference_file_frames, rec_file_frames, num_channels,
                    sample_rate):
    """Compares audio frames from the reference file and the recorded file.

    This method checks for two things, independently for every channel:
      1. That the main frequency is the same in both the files. This is done
         using the FFT and observing the frequency corresponding to the
         peak.
      2. That there is no other dominant frequency in the recorded file.
         This is done by sweeping the frequency domain and checking that,
         away from the reference frequency, the FFT magnitude is never more
         than the fraction _FFT_NOISE_THRESHOLD of the peak magnitude.

    The key assumption here is that the reference audio file contains only
    one frequency.

    @param reference_file_frames: Audio frames from the reference file.
    @param rec_file_frames: Audio frames from the recorded file.
    @param num_channels: Number of channels in the files.
    @param sample_rate: Sample rate of the files.

    @raise error.TestFail: The frequency of the recorded signal doesn't
                           match that of the reference signal.
    @raise error.TestFail: There is too much noise in the recorded signal.
    """
    for channel in range(num_channels):
        # Frames are interleaved; pick out the samples of this channel.
        reference_data = reference_file_frames[channel::num_channels]
        rec_data = rec_file_frames[channel::num_channels]

        # Get fft and frequencies corresponding to the fft values.
        fft_reference = numpy.fft.rfft(reference_data)
        fft_rec = numpy.fft.rfft(rec_data)
        fft_freqs_reference = numpy.fft.rfftfreq(len(reference_data),
                                                 1.0 / sample_rate)
        fft_freqs_rec = numpy.fft.rfftfreq(len(rec_data), 1.0 / sample_rate)

        # Get frequency at highest peak.
        freq_reference = fft_freqs_reference[
                numpy.argmax(numpy.abs(fft_reference))]
        abs_fft_rec = numpy.abs(fft_rec)
        freq_rec = fft_freqs_rec[numpy.argmax(abs_fft_rec)]

        # Compare the two frequencies.
        logging.info('Golden frequency of channel %i is %f', channel,
                     freq_reference)
        logging.info('Recorded frequency of channel %i is  %f', channel,
                     freq_rec)
        if _is_outside_frequency_threshold(freq_reference, freq_rec):
            raise error.TestFail('The recorded audio frequency does not match '
                                 'that of the audio played.')

        # Check for noise in the frequency domain.
        fft_rec_peak_val = numpy.max(abs_fft_rec)
        noise_detected = False
        for bin_freq, bin_magnitude in zip(fft_freqs_rec, abs_fft_rec):
            # Bins near the reference frequency belong to the expected tone.
            # The test must use the frequency of the bin being inspected,
            # not the already validated peak frequency, otherwise no bin
            # would ever be examined.
            if not _is_outside_frequency_threshold(freq_reference, bin_freq):
                continue
            # If the magnitude exceeds _FFT_NOISE_THRESHOLD, then fail.
            if bin_magnitude > _FFT_NOISE_THRESHOLD * fft_rec_peak_val:
                logging.warning('Unexpected frequency peak detected at %f '
                                'Hz.', bin_freq)
                noise_detected = True

        # Raise only after the sweep so every offending peak gets logged.
        if noise_detected:
            raise error.TestFail('Signal is noiser than expected.')
245
246
def compare_file(reference_audio_filename, test_audio_filename):
    """Compares the recorded audio file to the reference audio file.

    Both files are opened as WAV files, their samples are extracted and the
    per-channel frequency content is compared by _compare_frames.

    @param reference_audio_filename : Reference audio file containing the
                                      reference signal.
    @param test_audio_filename: Audio file containing audio captured from
                                the test.

    @raise error.TestFail: The recorded file does not have the same number of
                           channels or sample rate as the reference file, or
                           the comparison of the audio content fails.
    """
    with contextlib.closing(wave.open(reference_audio_filename,
                                      'rb')) as reference_file:
        with contextlib.closing(wave.open(test_audio_filename,
                                          'rb')) as rec_file:
            num_channels = reference_file.getnchannels()
            sample_rate = reference_file.getframerate()

            # The comparison splits both files into channels and maps FFT
            # bins to frequencies using the reference file's parameters, so
            # the recording has to share them for the result to be valid.
            if rec_file.getnchannels() != num_channels:
                raise error.TestFail(
                        'Recorded file has %d channels but the reference '
                        'file has %d.' % (rec_file.getnchannels(),
                                          num_channels))
            if rec_file.getframerate() != sample_rate:
                raise error.TestFail(
                        'Recorded file has sample rate %d but the reference '
                        'file has %d.' % (rec_file.getframerate(),
                                          sample_rate))

            # Extract data from files.
            reference_file_frames = extract_wav_frames(reference_file)
            rec_file_frames = extract_wav_frames(rec_file)

            _compare_frames(reference_file_frames, rec_file_frames,
                            num_channels, sample_rate)
267