# This file holds three independent snippets, each starting with its own
# `import torch`. Every snippet places tensors on a CUDA device, so the script
# cannot run on a machine without one.

# --- Snippet 1: scripted loop that repeatedly calls a scripted helper ---------
import torch


# Elementwise helper: multiplies `x` by `scale`, then divides by `shift`.
@torch.jit.script
def fn(x, scale, shift):
    return scale * x / shift


# Applies `fn` 100 times, feeding each result back in as the next `x`.
# `scale` and `shift` are reused unchanged on every iteration.
@torch.jit.script
def recurrent(x, scale, shift):
    y = x
    for i in range(100):
        y = fn(y, scale, shift)
    return y


# Only `scale` and `shift` track gradients; `x` does not.
x = torch.randn(2, 2, device="cuda")
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]


out = recurrent(x, scale, shift)
# The return value of graph_for is discarded here.
# NOTE(review): presumably called to inspect the specialized graph - confirm.
recurrent.graph_for(x, scale, shift)


# --- Snippet 2: scripted loop with the scale-and-shift written inline ---------
import torch


# Applies `y = scale * y + shift` 64 times, starting from `y = x`.
@torch.jit.script
def recurrent_scaleshift(x, scale, shift):
    y = x
    for i in range(64):
        y = scale * y + shift
    return y


# Rebinds the module-level names used by snippet 1 with fresh random tensors.
x = torch.randn(2, 2, device="cuda")
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]
out = recurrent_scaleshift(x, scale, shift)
# As above, the return value of graph_for is discarded.
recurrent_scaleshift.graph_for(x, scale, shift)


# --- Snippet 3: mean().backward() on an empty tensor, CPU then CUDA -----------
import torch


x = torch.tensor([])
x.requires_grad = True
x.mean().backward()  # no error triggered
# After this rebinding `x` is the CUDA copy of the original leaf tensor.
x = x.cuda()
# NOTE(review): given the comment on the CPU call above, this CUDA call is
# presumably the one being compared against it - confirm the expected outcome.
x.mean().backward()