#!/usr/bin/python3
# xref: /aosp_15_r20/external/libopus/dnn/training_tf2/train_lpcnet.py
#       (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
'''Copyright (c) 2018 Mozilla

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''

# Train an LPCNet model

import argparse
import os

from dataloader import LPCNetLoader

# Command-line interface. Positional arguments are the two binary training
# inputs and the output model prefix; everything else tunes the network size,
# sparsification densities, and training schedule.
parser = argparse.ArgumentParser(description='Train an LPCNet model')

parser.add_argument('features', metavar='<features file>', help='binary features file (float32)')
parser.add_argument('data', metavar='<audio data file>', help='binary audio data file (uint8)')
parser.add_argument('output', metavar='<output>', help='trained model file (.h5)')
parser.add_argument('--model', metavar='<model>', default='lpcnet', help='LPCNet model python definition (without .py)')
# --quantize and --retrain both resume from existing weights, so they are
# mutually exclusive.
group1 = parser.add_mutually_exclusive_group()
group1.add_argument('--quantize', metavar='<input weights>', help='quantize model')
group1.add_argument('--retrain', metavar='<input weights>', help='continue training model')
parser.add_argument('--density', metavar='<global density>', type=float, help='average density of the recurrent weights (default 0.1)')
parser.add_argument('--density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each recurrent gate (default 0.05, 0.05, 0.2)')
parser.add_argument('--grub-density', metavar='<global GRU B density>', type=float, help='average density of the recurrent weights (default 1.0)')
parser.add_argument('--grub-density-split', nargs=3, metavar=('<update>', '<reset>', '<state>'), type=float, help='density of each GRU B input gate (default 1.0, 1.0, 1.0)')
parser.add_argument('--grua-size', metavar='<units>', default=384, type=int, help='number of units in GRU A (default 384)')
parser.add_argument('--grub-size', metavar='<units>', default=16, type=int, help='number of units in GRU B (default 16)')
parser.add_argument('--cond-size', metavar='<units>', default=128, type=int, help='number of units in conditioning network, aka frame rate network (default 128)')
parser.add_argument('--epochs', metavar='<epochs>', default=120, type=int, help='number of epochs to train for (default 120)')
parser.add_argument('--batch-size', metavar='<batch size>', default=128, type=int, help='batch size to use (default 128)')
parser.add_argument('--end2end', dest='flag_e2e', action='store_true', help='Enable end-to-end training (with differentiable LPC computation)')
parser.add_argument('--lr', metavar='<learning rate>', type=float, help='learning rate')
parser.add_argument('--decay', metavar='<decay>', type=float, help='learning rate decay')
parser.add_argument('--gamma', metavar='<gamma>', type=float, help='adjust u-law compensation (default 2.0, should not be less than 1.0)')
parser.add_argument('--lookahead', metavar='<nb frames>', default=2, type=int, help='Number of look-ahead frames (default 2)')
parser.add_argument('--logdir', metavar='<log dir>', help='directory for tensorboard log files')
parser.add_argument('--lpc-gamma', type=float, default=1, help='gamma for LPC weighting')
parser.add_argument('--cuda-devices', metavar='<cuda devices>', type=str, default=None, help='string with comma separated cuda device ids')

args = parser.parse_args()
63*a58d3d2aSXin Li
# Set visible cuda devices before TensorFlow is imported, otherwise the
# restriction has no effect.
if args.cuda_devices is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_devices

# Per-gate density of the sparse GRU A recurrent weights, in (update, reset,
# state) order. A single --density value is split 0.5/0.5/2.0 so the state
# gate keeps more weights than the update/reset gates.
density = (0.05, 0.05, 0.2)
if args.density_split is not None:
    density = args.density_split
elif args.density is not None:
    density = [0.5*args.density, 0.5*args.density, 2.0*args.density]

# Same scheme for the GRU B input weight densities (dense by default).
grub_density = (1., 1., 1.)
if args.grub_density_split is not None:
    grub_density = args.grub_density_split
elif args.grub_density is not None:
    grub_density = [0.5*args.grub_density, 0.5*args.grub_density, 2.0*args.grub_density]

# u-law compensation exponent used by the interp_mulaw loss (e2e mode).
gamma = 2.0 if args.gamma is None else args.gamma
81*a58d3d2aSXin Li
import importlib
# The model definition module is chosen at runtime via --model (default
# 'lpcnet'), so it must be imported dynamically.
lpcnet = importlib.import_module(args.model)

import sys
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from ulaw import ulaw2lin, lin2ulaw
import tensorflow.keras.backend as K
import h5py

import tensorflow as tf
from tf_funcs import *
from lossfuncs import *
# Uncomment to cap per-GPU memory usage (useful when sharing a GPU):
#gpus = tf.config.experimental.list_physical_devices('GPU')
#if gpus:
#  try:
#    tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)])
#  except RuntimeError as e:
#    print(e)
102*a58d3d2aSXin Li
nb_epochs = args.epochs

# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size

quantize = args.quantize is not None
retrain = args.retrain is not None

# Number of LPC coefficients per frame appended to the feature vectors.
lpc_order = 16

# Learning-rate policy: quantization fine-tunes existing weights, so it uses
# a much smaller constant rate; training from scratch uses a larger rate with
# decay. Explicit --lr/--decay always override these defaults.
if quantize:
    lr = 0.00003
    decay = 0
    input_model = args.quantize
else:
    lr = 0.001
    decay = 5e-5

if args.lr is not None:
    lr = args.lr

if args.decay is not None:
    decay = args.decay

if retrain:
    input_model = args.retrain

flag_e2e = args.flag_e2e
131*a58d3d2aSXin Li
# beta_1/beta_2 are lower than the Adam defaults; these are the values the
# upstream LPCNet recipe trains with.
opt = Adam(lr, decay=decay, beta_1=0.5, beta_2=0.8)
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# Build (and compile) the model under the distribution strategy scope so its
# variables are mirrored across workers.
with strategy.scope():
    model, _, _ = lpcnet.new_lpcnet_model(rnn_units1=args.grua_size,
                                          rnn_units2=args.grub_size,
                                          batch_size=batch_size, training=True,
                                          quantize=quantize,
                                          flag_e2e=flag_e2e,
                                          cond_size=args.cond_size,
                                          lpc_gamma=args.lpc_gamma,
                                          lookahead=args.lookahead
                                          )
    if not flag_e2e:
        model.compile(optimizer=opt, loss=metric_cel, metrics=metric_cel)
    else:
        # End-to-end training adds a differentiable LPC branch: interpolated
        # mu-law loss on the pdf output plus an LAR matching loss.
        model.compile(optimizer=opt, loss = [interp_mulaw(gamma=gamma), loss_matchlar()], loss_weights = [1.0, 2.0], metrics={'pdf':[metric_cel,metric_icel,metric_exc_sd,metric_oginterploss]})
    model.summary()
150*a58d3d2aSXin Li
feature_file = args.features
pcm_file = args.data     # 16-bit PCM samples (memory-mapped below as int16)
frame_size = model.frame_size
nb_features = model.nb_used_features + lpc_order
nb_used_features = model.nb_used_features
# Each training sequence covers 15 feature frames.
feature_chunk_size = 15
pcm_chunk_size = frame_size*feature_chunk_size

# u for unquantised, load 16 bit PCM samples and convert to mu-law

data = np.memmap(pcm_file, dtype='int16', mode='r')
# Number of usable sequences, rounded down to a whole number of batches
# (the file stores 2 int16 values per sample position, hence the factor 2).
nb_frames = (len(data)//(2*pcm_chunk_size)-1)//batch_size*batch_size

features = np.memmap(feature_file, dtype='float32', mode='r')

# limit to discrete number of frames: skip the alignment offset implied by
# the look-ahead, then truncate to a whole number of sequences
data = data[(4-args.lookahead)*2*frame_size:]
data = data[:nb_frames*2*pcm_chunk_size]


data = np.reshape(data, (nb_frames, pcm_chunk_size, 2))

#print("ulaw std = ", np.std(out_exc))

# Build an overlapping view of the feature stream without copying: each
# sequence sees its own 15 frames plus 4 frames of context.
sizeof = features.strides[-1]
features = np.lib.stride_tricks.as_strided(features, shape=(nb_frames, feature_chunk_size+4, nb_features),
                                           strides=(feature_chunk_size*nb_features*sizeof, nb_features*sizeof, sizeof))
#features = features[:, :, :nb_used_features]


# Decode the normalized pitch feature back into an integer period in samples.
periods = (.1 + 50*features[:,:,nb_used_features-2:nb_used_features-1]+100).astype('int16')
#periods = np.minimum(periods, 255)
183*a58d3d2aSXin Li
# dump models to disk as we go
checkpoint = ModelCheckpoint('{}_{}_{}.h5'.format(args.output, args.grua_size, '{epoch:02d}'))

if quantize or retrain:
    # Adapting from an existing model. --quantize and --retrain are mutually
    # exclusive, and input_model was set from whichever one was given, so a
    # single load covers both cases (the original loaded args.retrain twice).
    model.load_weights(input_model)
    if quantize:
        sparsify = lpcnet.Sparsify(10000, 30000, 100, density, quantize=True)
        grub_sparsify = lpcnet.SparsifyGRUB(10000, 30000, 100, args.grua_size, grub_density, quantize=True)
    else:
        # Retraining: weights are already sparse, apply the masks immediately.
        sparsify = lpcnet.Sparsify(0, 0, 1, density)
        grub_sparsify = lpcnet.SparsifyGRUB(0, 0, 1, args.grua_size, grub_density)
else:
    #Training from scratch
    sparsify = lpcnet.Sparsify(2000, 20000, 400, density)
    grub_sparsify = lpcnet.SparsifyGRUB(2000, 40000, 400, args.grua_size, grub_density)

model.save_weights('{}_{}_initial.h5'.format(args.output, args.grua_size))

loader = LPCNetLoader(data, features, periods, batch_size, e2e=flag_e2e, lookahead=args.lookahead)

callbacks = [checkpoint, sparsify, grub_sparsify]
if args.logdir is not None:
    logdir = '{}/{}_{}_logs'.format(args.logdir, args.output, args.grua_size)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
    callbacks.append(tensorboard_callback)

model.fit(loader, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
215