xref: /aosp_15_r20/external/libopus/dnn/torch/osce/utils/lpcnet_features.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import os
2
3import torch
4import numpy as np
5
def load_lpcnet_features(feature_file, version=2):
    """Load an LPCNet feature dump and split it into model inputs.

    The file is read as a flat float32 array and reshaped into frames of
    ``frame_length`` values (36 for version 2, 55 for version 1). Each frame is
    split into:

    * ``features``: the 18 values of the cepstrum slot followed by the pitch
      correlation value (19 columns),
    * ``periods``: the period column mapped through ``0.1 + 50 * p + 100`` and
      converted with ``.long()`` (shape ``(frames, 1)``),
    * ``lpcs``: the 16 LPC values.

    Args:
        feature_file: path handed to ``np.fromfile``.
        version: feature layout version, ``1`` or ``2``.

    Returns:
        dict with the torch tensors ``'features'``, ``'periods'`` and ``'lpcs'``.

    Raises:
        ValueError: if ``version`` is not a supported layout.
    """
    # Half-open column ranges of each field inside one frame, per layout version.
    if version == 2:
        frame_length = 36
        cepstrum, period, pitch_corr, lpc = slice(0, 18), slice(18, 19), slice(19, 20), slice(20, 36)
    elif version == 1:
        frame_length = 55
        cepstrum, period, pitch_corr, lpc = slice(0, 18), slice(36, 37), slice(37, 38), slice(39, 55)
    else:
        raise ValueError(f'unknown feature version: {version}')

    frames = torch.from_numpy(np.fromfile(feature_file, dtype='float32')).reshape((-1, frame_length))

    features = torch.cat((frames[:, cepstrum], frames[:, pitch_corr]), dim=1)
    lpcs = frames[:, lpc]
    # Evaluation order kept as (0.1 + 50 * p) + 100 so float rounding is unchanged.
    periods = (0.1 + 50 * frames[:, period] + 100).long()

    return {'features': features, 'periods': periods, 'lpcs': lpcs}
43
44
45
def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize a raw int16 signal and write it, sample-duplicated, into a new data file.

    Two int16 files are written:

    * ``<signal_path without extension>_preemph.raw``: the signal filtered with
      ``y[n] = x[n] - preemph_factor * x[n - 1]`` (zero initial state), clipped to
      the int16 range and truncated toward zero on conversion.
    * ``new_data_path``: same length as the reference data file, zero-filled except
      that each pre-emphasized sample ``y[offset + k]`` is written to positions
      ``2k + 1`` and ``2k + 2``.

    Args:
        signal_path: raw int16 signal; its length must be a multiple of 160 samples.
        reference_data_path: raw int16 file, used only for its length. It must hold
            at least ``2 * (len(signal) - offset) + 1`` samples, otherwise the
            interleaving assignment fails.
        new_data_path: path of the interleaved output file.
        offset: number of leading pre-emphasized samples to skip.
        preemph_factor: pre-emphasis coefficient.

    Raises:
        ValueError: if the signal length is not a multiple of 160 samples. The check
            runs before any output file is created.
    """
    frame_size = 160
    int16_min, int16_max = np.iinfo(np.int16).min, np.iinfo(np.int16).max

    # Inputs are only read. np.memmap defaults to mode 'r+', which needlessly
    # requires write permission on the input files.
    ref_data = np.memmap(reference_data_path, dtype=np.int16, mode='r')
    signal = np.memmap(signal_path, dtype=np.int16, mode='r')

    # Validate before creating any output so bad input leaves no partial file behind;
    # unlike `assert`, this check also survives `python -O`.
    if len(signal) % frame_size != 0:
        raise ValueError(
            f'signal length {len(signal)} is not a multiple of {frame_size} samples'
        )

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    # Filter frame by frame to keep memory bounded; `mem` carries the last input
    # sample of the previous frame into the next convolution.
    taps = [1, -preemph_factor]
    mem = np.zeros(1)
    for start in range(0, len(signal), frame_size):
        frame = signal[start : start + frame_size]
        filtered = np.convolve(np.concatenate((mem, frame)), taps, mode='valid')
        # x[n] - a * x[n-1] can leave the int16 range for loud input; clip rather
        # than rely on an out-of-range float -> int16 cast, whose result is undefined.
        signal_preemph[start : start + frame_size] = np.clip(filtered, int16_min, int16_max)
        mem = frame[-1:]
    signal_preemph.flush()

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)
    new_data[:] = 0
    N = len(signal) - offset
    # Every kept sample is duplicated: odd slots 1, 3, ... and even slots 2, 4, ...
    new_data[1 : 2*N + 1 : 2] = signal_preemph[offset:]
    new_data[2 : 2*N + 2 : 2] = signal_preemph[offset:]
    new_data.flush()
67
68
def parse_warpq_scores(output_file):
    """Return the WARP-Q scores found in ``output_file``.

    Each line that begins with ``WARP-Q score:`` contributes one float, parsed from
    the text after the last occurrence of that marker. All other lines are ignored.
    """
    marker = "WARP-Q score:"
    scores = []
    with open(output_file, "r") as handle:
        for line in handle:
            if line.startswith(marker):
                # rpartition(...)[2] is the text after the last marker, like split(...)[-1].
                scores.append(float(line.rpartition(marker)[2]))
    return scores
77    return scores
78
79
def parse_stats_file(file):
    """Read the three summary values from a stats file.

    Lines 1-3 are parsed as floats from the text after their last ``:``.

    Returns:
        tuple ``(mean, bt_mean, top_mean)`` from lines 1, 2 and 3 respectively.
        Raises IndexError if the file has fewer than three lines.
    """
    with open(file, "r") as stats:
        lines = stats.readlines()

    def value_at(index):
        # rpartition(":")[2] is the text after the last colon, like split(":")[-1].
        return float(lines[index].rpartition(":")[2])

    return value_at(0), value_at(1), value_at(2)
90
def collect_test_stats(test_folder):
    """Collect summary statistics for every metric found in ``test_folder``.

    A metric ``<name>`` is discovered from a regular file called
    ``stats_<name>.txt`` directly inside ``test_folder``. Its contents are parsed
    with ``parse_stats_file``. Metrics outside the known set are still collected,
    but a warning is printed for them.

    Returns:
        dict mapping metric name to ``[mean, bt_mean, top_mean]``.
    """
    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
    prefix, suffix = "stats_", ".txt"

    results = dict()

    # Sorted so the warning order and the result order do not depend on os.listdir order.
    for file in sorted(os.listdir(test_folder)):
        path = os.path.join(test_folder, file)

        # The metric name is obtained by stripping exactly the prefix and the ".txt"
        # suffix, so both must be present. Other "stats_*" entries (different
        # extensions, directories) previously yielded garbled names and parse errors.
        if not (file.startswith(prefix) and file.endswith(suffix) and os.path.isfile(path)):
            continue

        metric = file[len(prefix) : -len(suffix)]

        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(path)

        results[metric] = [mean, bt_mean, top_mean]

    return results
113