xref: /aosp_15_r20/external/armnn/python/pyarmnn/examples/common/audio_capture.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3"""Contains CaptureAudioStream class for capturing chunks of audio data from incoming
4  stream and generic capture_audio function for capturing from files."""
5import collections
6import time
7from queue import Queue
8from typing import Generator
9
10import numpy as np
11import sounddevice as sd
12import soundfile as sf
13
# Capture parameters shared by capture_audio and CaptureAudioStream:
#   dtype         - sample data type requested from the file / input stream
#   overlap       - number of samples shared between consecutive windows
#   min_samples   - number of samples per window (also the stream block size)
#   sampling_freq - sampling rate in samples per second
#   mono          - if true, audio is reduced to a single channel
AudioCaptureParams = collections.namedtuple(
    'AudioCaptureParams',
    'dtype overlap min_samples sampling_freq mono',
)
16
17
def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]:
    """Creates a generator that yields audio data from a file.

    The file is read in blocks of ``params_tuple.min_samples`` frames, with
    consecutive blocks sharing ``params_tuple.overlap`` frames. Data is padded
    with zeros if necessary to make up the minimum number of samples.

    Args:
        audio_file_path: Path to audio file provided by user.
        params_tuple: Sampling parameters for the model used (``AudioCaptureParams``);
            ``dtype``, ``min_samples``, ``overlap`` and ``mono`` are read here.

    Yields:
        Blocks of audio data of minimum sample size. If ``params_tuple.mono`` is
        set, each block is 1-D with all channels averaged together; otherwise it
        is a 2-D ``(frames, channels)`` array.
    """
    with sf.SoundFile(audio_file_path) as audio_file:
        for block in audio_file.blocks(
                blocksize=params_tuple.min_samples,
                dtype=params_tuple.dtype,
                always_2d=True,
                fill_value=0,
                overlap=params_tuple.overlap
        ):
            if params_tuple.mono:
                # always_2d=True yields (frames, channels), so channels live on axis 1.
                # Averaging in float64 and casting back avoids overflowing integer
                # sample types (e.g. int16) while summing channels; for a single
                # channel this is an exact identity.
                block = block.mean(axis=1).astype(block.dtype)
            yield block
38
39
class CaptureAudioStream:
    """Captures windows of audio data from a live ``sounddevice`` input stream.

    ``callback`` is meant to be registered with a ``sounddevice`` input stream and
    pushes each incoming block into a small bounded queue. ``capture_data`` pulls
    blocks from that queue and assembles windows of ``min_samples`` samples, where
    consecutive windows share ``overlap`` samples.
    """

    def __init__(self, audio_capture_params):
        """Args:
            audio_capture_params: An ``AudioCaptureParams`` tuple
                (dtype, overlap, min_samples, sampling_freq, mono).
        """
        self.audio_capture_params = audio_capture_params
        rows = audio_capture_params.min_samples + audio_capture_params.overlap
        # Mono blocks are flattened by ``callback``; otherwise blocks arrive as
        # (frames, channels) with the 2 channels configured in set_stream_defaults.
        shape = rows if audio_capture_params.mono else (rows, 2)
        self.collection = np.zeros(shape, dtype=audio_capture_params.dtype)
        self.is_active = True
        self.is_first_window = True
        self.duration = False  # False => record forever; otherwise total samples to record
        self.block_count = 0
        self.current_block = 0
        # Bounded: put() blocks once two blocks are waiting to be consumed.
        self.queue = Queue(2)

    def set_stream_defaults(self):
        """Discovers input devices on the system and sets default stream parameters.

        Prints the available devices, prompts the user for one by index or name, then
        sets the ``sounddevice`` default sample rate, block size, dtype and channel
        count from the capture parameters.
        """
        print(sd.query_devices())
        device = input("Select input device by index or name: ")

        try:
            sd.default.device = int(device)
        except ValueError:
            # Not an integer index, so treat the input as a device name.
            sd.default.device = str(device)

        sd.default.samplerate = self.audio_capture_params.sampling_freq
        # One stream block per model window.
        sd.default.blocksize = self.audio_capture_params.min_samples
        sd.default.dtype = self.audio_capture_params.dtype
        sd.default.channels = 1 if self.audio_capture_params.mono else 2

    def set_recording_duration(self, duration):
        """Sets a time duration (in integer seconds) for recording audio. Total time duration is
        adjusted to a minimum based on the parameters of the model used. Durations less than 1
        result in endless recording.

        Args:
            duration (int): User-provided command line argument for time duration of recording.
        """
        params = self.audio_capture_params
        if duration > 0:
            min_duration = int(np.ceil(params.min_samples / params.sampling_freq))
            if duration < min_duration:
                print(f"Minimum duration must be {min_duration} seconds of audio")
                print("Setting minimum recording duration...")
                duration = min_duration

            print(f"Recording duration is {duration} seconds")
            self.duration = params.sampling_freq * duration
            self.block_count, remainder_samples = divmod(self.duration, params.min_samples)

            # Record one extra block when the leftover exceeds half a second of audio.
            if remainder_samples > 0.5 * params.sampling_freq:
                self.block_count += 1
        else:
            self.duration = False  # Record forever

    def countdown(self, delay=3):
        """Prints a countdown of ``delay`` seconds (3 by default) before recording audio."""
        print("Beginning recording in...")
        for i in range(delay, 0, -1):
            print(f"{i}...")
            time.sleep(1)

    def update(self):
        """If a duration has been set, increments a counter to update the number of blocks of audio
        data left to be collected. The stream is deactivated upon reaching the maximum block count
        determined by the duration.
        """
        if self.duration:
            self.current_block += 1
            # >= rather than == so the stream still stops if the count is ever overshot.
            if self.current_block >= self.block_count:
                self.is_active = False

    def capture_data(self):
        """Gets the next window of audio data from the queue.

        The first window (or every window when ``overlap`` is 0) is the newest block
        as-is. After that, each window keeps the last ``overlap`` samples of the previous
        window and appends the first ``min_samples - overlap`` samples of the new block.
        The final ``overlap`` samples of each new block are not included in any window.

        Returns:
            A copy of the current ``min_samples``-long window, cast to ``dtype``.

        Raises:
            ValueError: If ``overlap`` is not in the range ``[0, min_samples)``.
        """
        min_samples = self.audio_capture_params.min_samples
        overlap = self.audio_capture_params.overlap

        # Validate before dequeuing so a bad configuration does not discard audio.
        if overlap >= min_samples:
            raise ValueError("Capture Error: Overlap must be less than {}".format(min_samples))
        if overlap < 0:
            raise ValueError("Capture Error: Overlap must be non-negative, got {}".format(overlap))

        new_data = self.queue.get()

        if self.is_first_window or overlap == 0:
            self.collection[:min_samples] = new_data[:]
        else:
            # Carry the tail of the previous window to the front, then place the new block after it.
            self.collection[:overlap] = self.collection[min_samples - overlap:min_samples]
            self.collection[overlap:overlap + min_samples] = new_data[:]
        # Any later call is no longer the first window, so the overlap path becomes reachable.
        self.is_first_window = False

        audio_data = self.collection[:min_samples]
        # astype returns a copy, so callers never alias the reused collection buffer.
        return np.asarray(audio_data).astype(self.audio_capture_params.dtype)

    def callback(self, data, frames, time, status):
        """Places audio data from active stream into a queue for processing.
        Update counter if recording duration is finite.

        Intended as a ``sounddevice`` input-stream callback; ``frames``, ``time`` and
        ``status`` are accepted for that signature but not used here.
        """
        # After deactivation the consumer has presumably stopped draining the queue;
        # a blocking put() on the full bounded queue would then stall the audio
        # callback thread, so late blocks are dropped instead.
        if not self.is_active:
            return

        if self.duration:
            self.update()

        if self.audio_capture_params.mono:
            audio_data = data.flatten()  # flatten() already returns a copy
        else:
            audio_data = data.copy()

        self.queue.put(audio_data)
150