from collections import namedtuple
from typing import List, Tuple

import torch
from torch import Tensor

from .cells import (
    flat_lstm_cell,
    lstm_cell,
    premul_lstm_cell,
    premul_lstm_cell_no_bias,
)


# list[list[T]] -> list[T]
def flatten_list(lst):
    result = []
    for inner in lst:
        result.extend(inner)
    return result


"""
Define a creator as a function:
    (options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
    One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
    Then, we pass backward_inputs to backward. If None, then it is assumed to
    be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
    backpropagation. If None, then nothing happens.

fastrnns.bench times the forward and backward invocations.
"""


ModelDef = namedtuple(
    "ModelDef", ["inputs", "params", "forward", "backward_setup", "backward"]
)


def lstm_backward_setup(lstm_outputs, seed=None):
    hx, _ = lstm_outputs
    return simple_backward_setup(hx, seed)


def simple_backward_setup(output, seed=None):
    assert isinstance(output, torch.Tensor)
    if seed:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output


def simple_backward(output, grad_output, **kwargs):
    return output.backward(grad_output, **kwargs)


def pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    assert script is True
    from .custom_lstms import script_lnlstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    ge = script_lnlstm(
        input_size, hidden_size, 1, decompose_layernorm=decompose_layernorm
    ).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        (
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
    ]

    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def dropoutlstm_creator(script=True, **kwargs):
    assert script is True
    from .custom_lstms import LSTMState, script_lstm

    input_size = kwargs["inputSize"]
    hidden_size = kwargs["hiddenSize"]
    seq_len = kwargs["seqLength"]
    batch_size = kwargs["miniBatch"]
    num_layers = kwargs["numLayers"]
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device="cuda")
    states = [
        LSTMState(
            torch.randn(batch_size, hidden_size, device="cuda"),
            torch.randn(batch_size, hidden_size, device="cuda"),
        )
        for _ in range(num_layers)
    ]

    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )
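# Illustrative sketch (not called anywhere in this module): one plausible way to
# drive a ModelDef produced by the creators above, following the contract in the
# module docstring. The real timing loop lives in fastrnns.bench and may differ
# in details; the helper name and its argument are hypothetical.
def _example_run_model_def(model_def):
    forward_output = model_def.forward(*model_def.inputs)
    if model_def.backward_setup is not None:
        # e.g. lstm_backward_setup returns (output, grad_output)
        backward_input = model_def.backward_setup(forward_output)
    else:
        backward_input = forward_output
    if model_def.backward is not None:
        # e.g. simple_backward calls output.backward(grad_output);
        # a benchmark harness would also reset model_def.params' grads here
        model_def.backward(*backward_input)
    return forward_output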
def lstm_premul_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul(premul_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_premul_bias_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_simple_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input] + [h[0] for h in hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_simple(flat_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def lstm_multilayer_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden, flatten_list(params)]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_multilayer(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def imagenet_cnn_creator(arch, jit=True):
    def creator(device="cuda", **kwargs):
        model = arch().to(device)
        x = torch.randn(32, 3, 224, 224, device=device)
        if jit:
            model = torch.jit.trace(model, x)
        return ModelDef(
            inputs=(x,),
            params=list(model.parameters()),
            forward=model,
            backward_setup=simple_backward_setup,
            backward=simple_backward,
        )

    return creator


def varlen_lstm_inputs(
    minlen=30,
    maxlen=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    return_module=False,
    device="cuda",
    seed=None,
    **kwargs,
):
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch], dtype=torch.long, device=device
    )
    x = [torch.randn(length, inputSize, device=device) for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    if return_module:
        return x, lengths, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, lengths, (hx, cx), lstm.all_weights, None


def varlen_lstm_backward_setup(forward_output, seed=None):
    if seed:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad
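# Illustrative sketch (not exercised by the benchmarks): the pad/pack round trip
# that the variable-length creators below rely on. `pad_sequence` stacks ragged
# (length_i, inputSize) tensors into a zero-padded (max_len, batch, inputSize)
# tensor, while `pack_sequence(..., enforce_sorted=False)` produces the packed
# form that torch.nn.LSTM consumes. Tensor sizes here are arbitrary toy values.
def _example_varlen_round_trip():
    rnn_utils = torch.nn.utils.rnn
    sequences = [torch.randn(length, 8) for length in (5, 3, 4)]
    padded = rnn_utils.pad_sequence(sequences)  # (5, 3, 8), zero-padded
    packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
    # pad_packed_sequence restores the padded layout and the original lengths
    unpacked, lengths = rnn_utils.pad_packed_sequence(packed)  # (5, 3, 8), tensor([5, 3, 4])
    return padded, unpacked, lengths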
def varlen_pytorch_lstm_creator(**kwargs):
    rnn_utils = torch.nn.utils.rnn
    sequences, _, hidden, _, module = varlen_lstm_inputs(return_module=True, **kwargs)

    def forward(sequences, hidden):
        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        out, new_hidden = module(packed, hidden)
        padded, lengths = rnn_utils.pad_packed_sequence(out)
        # XXX: It's more efficient to store the output in its padded form,
        # but that might not be conducive to loss computation.
        # Un-padding the output also makes the backward pass 2x slower...
        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
        return padded, new_hidden

    return ModelDef(
        inputs=[sequences, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward,
    )


def varlen_lstm_factory(cell, script):
    def dynamic_rnn(
        sequences: List[Tensor],
        hiddens: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # List of: (output, hx, cx)
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh
                )
                output += [hy]
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def varlen_lstm_creator(script=False, **kwargs):
    sequences, _, hidden, params, _ = varlen_lstm_inputs(return_module=False, **kwargs)
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward,
    )
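# For reference: a sketch of the gate arithmetic that the cells imported from
# .cells are assumed to implement for the (wih, whh, bih, bhh) parameter layout
# used throughout this file. These are the standard LSTM cell equations; the
# benchmarked implementations live in fastrnns.cells and may be written
# differently (scripted, fused, or decomposed).
def _reference_lstm_cell(
    x: Tensor,
    hidden: Tuple[Tensor, Tensor],
    wih: Tensor,
    whh: Tensor,
    bih: Tensor,
    bhh: Tensor,
) -> Tuple[Tensor, Tensor]:
    hx, cx = hidden
    gates = torch.mm(x, wih.t()) + torch.mm(hx, whh.t()) + bih + bhh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy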
# cudnn_layernorm_lstm: since cudnn does not have a LayerNorm LSTM, we cannot
# benchmark the lower bound directly. Instead, we only benchmark the forward pass
# by mimicking the computation of a cudnn LSTM plus seq_len * 3 LayerNorm
# computations. This serves as a perf lower bound for the LayerNorm LSTM forward
# pass (given that LayerNorm itself is invariant). A lower bound for the backward
# pass is hard to get since we lose the intermediate results, but we can still
# optimize the LayerNorm implementation to obtain a faster forward lower bound.
def layernorm_pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    batch_size = kwargs["miniBatch"]
    hidden_size = kwargs["hiddenSize"]
    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device="cuda")

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        # plus (seq_len * three layernorm cell computations) to mimic the lower
        # bound of a LayerNorm cudnn LSTM in the forward pass
        seq_len = len(input.unbind(0))
        hy, cy = new_hidden
        for i in range(seq_len):
            ln_i_output = ln_i(ln_input1)
            ln_h_output = ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None,
    )


# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
# packed_weights[2] is bih with size (layer, 4*hiddenSize)
# packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
    def unzip_columns(mat):
        assert isinstance(mat, list)
        assert isinstance(mat[0], list)
        layers = len(mat)
        columns = len(mat[0])
        return [[mat[layer][col] for layer in range(layers)] for col in range(columns)]

    # XXX: script fns have problems indexing multidim lists, so we try to
    # avoid them by stacking tensors
    all_weights = weights
    packed_weights = [torch.stack(param) for param in unzip_columns(all_weights)]
    return packed_weights


# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    dropout=0.0,
    return_module=False,
    device="cuda",
    seed=None,
):
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if "cuda" in device:
        lstm = lstm.cuda()

    if return_module:
        return x, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, (hx, cx), lstm.all_weights, None


def lstm_factory(cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = input.unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
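# Illustrative sketch (not used by the benchmarks): the algebraic identity the
# premul factories below exploit. Applying the input projection to the whole
# (seq_len, batch, inputSize) tensor at once is equivalent to applying it
# timestep by timestep, so the per-step work inside the recurrence loop shrinks
# to the hidden-to-hidden matmul. Sizes here are arbitrary toy values.
def _example_premul_equivalence():
    seq_len, batch, input_size, hidden_size = 4, 3, 8, 16
    x = torch.randn(seq_len, batch, input_size)
    wih = torch.randn(4 * hidden_size, input_size)
    per_step = torch.stack([torch.mm(x[t], wih.t()) for t in range(seq_len)])
    premultiplied = torch.matmul(x, wih.t())  # (seq_len, batch, 4 * hidden_size)
    assert torch.allclose(per_step, premultiplied, atol=1e-5)
    return premultiplied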
# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = torch.matmul(input, wih.t()).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    def dynamic_rnn(
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        wih: Tensor,
        whh: Tensor,
        bih: Tensor,
        bhh: Tensor,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inpSize = input.size()
        # add bias for all timesteps instead of going step-by-step, resulting in
        # a single reduction kernel in the backward
        # FIXME: matmul(x, y) + bias currently goes through jit AD, and the
        # backward formula in AD is not optimized for this case. Workaround
        # with mm and views.
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def lstm_factory_multilayer(cell, script):
    def dynamic_rnn(
        input: Tensor, hidden: Tuple[Tensor, Tensor], params: List[Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
        hx, cx = hidden
        hy, cy = hidden  # for scoping...
        inputs, outputs = input.unbind(0), []
        for layer in range(hx.size(0)):
            hy = hx[layer]
            cy = cx[layer]
            base_idx = layer * params_stride
            wih = params[base_idx]
            whh = params[base_idx + 1]
            bih = params[base_idx + 2]
            bhh = params[base_idx + 3]
            for seq_idx in range(len(inputs)):
                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
                outputs += [hy]
            inputs, outputs = outputs, []
        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn
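# Illustrative sketch (not used by the benchmarks): how the flat parameter list
# that lstm_multilayer_creator feeds to lstm_factory_multilayer maps back onto
# per-layer (wih, whh, bih, bhh) tuples, given the stride-4 layout assumed by
# params_stride above. The helper name is hypothetical.
def _example_split_flat_params(
    flat_params: List[Tensor],
) -> List[Tuple[Tensor, Tensor, Tensor, Tensor]]:
    assert len(flat_params) % 4 == 0, "expects biases to be present"
    layers = []
    for base_idx in range(0, len(flat_params), 4):
        wih, whh, bih, bhh = flat_params[base_idx : base_idx + 4]
        layers.append((wih, whh, bih, bhh))
    return layers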