import os

import torch
import numpy as np


def load_lpcnet_features(feature_file, version=2):
    """Load a raw LPCNet feature dump and split it into torch tensors.

    The file is read as a flat float32 array and reshaped into frames of
    ``frame_length`` values (36 for version 2, 55 for version 1); the reshape
    fails if the file does not hold a whole number of frames. Which columns
    hold which feature depends on ``version`` (see ``layout`` below).

    Args:
        feature_file: path (or any file object ``np.fromfile`` accepts) of the
            raw float32 feature file.
        version: feature layout version, 1 or 2.

    Returns:
        dict with
          'features': the 18 cepstrum columns followed by the pitch
              correlation column, float32 tensor of shape (num_frames, 19);
          'periods': values decoded from the period column as
              ``long(0.1 + 50 * x + 100)``, int64 tensor of shape
              (num_frames, 1);
          'lpcs': the 16 LPC columns, float32 tensor of shape
              (num_frames, 16).

    Raises:
        ValueError: if ``version`` is neither 1 nor 2.
    """
    # layout maps feature name -> [start, stop) column range within one frame
    if version == 2:
        layout = {
            'cepstrum': [0, 18],
            'periods': [18, 19],
            'pitch_corr': [19, 20],
            'lpc': [20, 36]
        }
        frame_length = 36

    elif version == 1:
        layout = {
            'cepstrum': [0, 18],
            'periods': [36, 37],
            'pitch_corr': [37, 38],
            'lpc': [39, 55],
        }
        frame_length = 55
    else:
        raise ValueError(f'unknown feature version: {version}')

    raw_features = torch.from_numpy(np.fromfile(feature_file, dtype='float32'))
    raw_features = raw_features.reshape((-1, frame_length))

    features = torch.cat(
        [
            raw_features[:, layout['cepstrum'][0] : layout['cepstrum'][1]],
            raw_features[:, layout['pitch_corr'][0] : layout['pitch_corr'][1]]
        ],
        dim=1
    )

    lpcs = raw_features[:, layout['lpc'][0] : layout['lpc'][1]]
    # decode the stored period value into an integer period (truncating cast)
    periods = (0.1 + 50 * raw_features[:, layout['periods'][0] : layout['periods'][1]] + 100).long()

    return {'features' : features, 'periods' : periods, 'lpcs' : lpcs}


def create_new_data(signal_path, reference_data_path, new_data_path, offset=320, preemph_factor=0.85):
    """Pre-emphasize an int16 signal and write it, sample-duplicated, into a new data file.

    1. ``signal_path`` (raw int16) is filtered with the first-order
       pre-emphasis ``y[n] = x[n] - preemph_factor * x[n - 1]`` (with
       ``x[-1] = 0``), 160 samples at a time. The result is clipped to the
       int16 range, cast to int16 (truncating) and written to
       ``<signal_path without extension>_preemph.raw``.
    2. ``new_data_path`` is created as a zero-filled int16 file with the same
       number of samples as ``reference_data_path`` (only its size is used).
       For ``k = 0 .. len(signal) - offset - 1`` both ``new_data[2k + 1]``
       and ``new_data[2k + 2]`` receive ``y[offset + k]``.

    Args:
        signal_path: raw int16 input signal; its length must be a multiple
            of 160 samples.
        reference_data_path: raw int16 file whose sample count sets the size
            of the output; it must hold at least
            ``2 * (len(signal) - offset) + 1`` samples.
        new_data_path: output path (created or overwritten).
        offset: number of leading pre-emphasized samples to skip.
        preemph_factor: pre-emphasis coefficient.

    Raises:
        ValueError: if the signal length is not a multiple of 160.
    """
    frame_size = 160

    # The inputs are only read, so map them read-only (np.memmap defaults to 'r+').
    ref_data = np.memmap(reference_data_path, dtype=np.int16, mode='r')
    signal = np.memmap(signal_path, dtype=np.int16, mode='r')

    # Validate before creating any output file; a plain assert would vanish under -O.
    if len(signal) % frame_size != 0:
        raise ValueError(f'signal length {len(signal)} is not a multiple of {frame_size}')
    num_frames = len(signal) // frame_size

    signal_preemph_path = os.path.splitext(signal_path)[0] + '_preemph.raw'
    signal_preemph = np.memmap(signal_preemph_path, dtype=np.int16, mode='write', shape=signal.shape)

    # Filtering frame by frame avoids materializing a float64 copy of the whole
    # signal; ``mem`` carries the previous frame's last input sample across frames.
    mem = np.zeros(1)
    for fr in range(num_frames):
        start, stop = fr * frame_size, (fr + 1) * frame_size
        frame = signal[start:stop]
        filtered = np.convolve(np.concatenate((mem, frame)), [1, -preemph_factor], mode='valid')
        # The filter output can leave the int16 range (e.g. 32767 - 0.85 * -32768);
        # clip so such samples saturate instead of wrapping around on the cast.
        signal_preemph[start:stop] = np.clip(filtered, -32768, 32767)
        mem = frame[-1:]
    signal_preemph.flush()

    new_data = np.memmap(new_data_path, dtype=np.int16, mode='write', shape=ref_data.shape)

    new_data[:] = 0
    N = len(signal) - offset
    # each pre-emphasized sample from ``offset`` on is written twice:
    # at odd index 2k+1 and at the following even index 2k+2
    new_data[1 : 2 * N + 1 : 2] = signal_preemph[offset:]
    new_data[2 : 2 * N + 2 : 2] = signal_preemph[offset:]
    new_data.flush()


def parse_warpq_scores(output_file):
    """Extract WARP-Q scores from a text output file.

    Returns the floats that follow the ``"WARP-Q score:"`` prefix on every
    line starting with that prefix, in file order.
    """

    with open(output_file, "r") as f:
        lines = f.readlines()

    scores = [float(line.split("WARP-Q score:")[-1]) for line in lines if line.startswith("WARP-Q score:")]

    return scores


def parse_stats_file(file):
    """Read the first three lines of a stats file.

    Each line is expected to end with ``:<number>``; the number after the
    last ``:`` is parsed as a float.

    Returns:
        tuple ``(mean, bt_mean, top_mean)`` taken from lines 1, 2 and 3.
    """

    with open(file, "r") as f:
        lines = f.readlines()

    mean = float(lines[0].split(":")[-1])
    bt_mean = float(lines[1].split(":")[-1])
    top_mean = float(lines[2].split(":")[-1])

    return mean, bt_mean, top_mean


def collect_test_stats(test_folder):
    """Collect statistics for every ``stats_<metric>.txt`` file in ``test_folder``.

    Metrics outside the known set only trigger a printed warning; they are
    still collected.

    Returns:
        dict mapping metric name to ``[mean, bt_mean, top_mean]`` as returned
        by ``parse_stats_file``.
    """

    known_metrics = {'pesq', 'warpq', 'pitch_error', 'voicing_error'}
    prefix, suffix = 'stats_', '.txt'

    results = dict()

    for file in os.listdir(test_folder):
        # Require the suffix as well: the metric name is cut out by stripping it,
        # so e.g. "stats_notes" must not be turned into the metric "n".
        if not (file.startswith(prefix) and file.endswith(suffix)):
            continue

        metric = file[len(prefix) : -len(suffix)]

        if metric not in known_metrics:
            print(f"warning: unknown metric {metric}")

        mean, bt_mean, top_mean = parse_stats_file(os.path.join(test_folder, file))

        results[metric] = [mean, bt_mean, top_mean]

    return results