# xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/scratch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import torch


@torch.jit.script
def fn(x, scale, shift):
    """Return ``scale * x`` divided elementwise by ``shift``.

    Note: despite its name, ``shift`` acts as a divisor here, not as an
    additive offset.
    """
    scaled = scale * x
    return scaled / shift


@torch.jit.script
def recurrent(x, scale, shift, steps: int = 100):
    """Feed ``x`` through ``fn`` repeatedly, ``steps`` times.

    Each iteration computes ``y = fn(y, scale, shift)``. Presumably used to
    inspect how the JIT optimizes a long chain of elementwise ops (this
    scratch file calls ``graph_for`` on it) -- TODO confirm.

    Args:
        x: starting tensor.
        scale: forwarded unchanged to ``fn`` on every step.
        shift: forwarded unchanged to ``fn`` on every step.
        steps: number of iterations; defaults to 100, the count this function
            originally hard-coded. ``steps <= 0`` returns ``x`` unchanged.

    Returns:
        The tensor produced after ``steps`` applications of ``fn``.
    """
    y = x
    for _ in range(steps):  # loop index is not needed
        y = fn(y, scale, shift)
    return y


# Draw the three CUDA inputs in order (x, scale, shift); only scale and shift
# track gradients. Then run ``recurrent`` once and fetch the graph the JIT
# produced for these arguments. The graph is not printed or stored here,
# presumably because this snippet was meant for interactive inspection.
x, scale, shift = (
    torch.randn(2, 2, device="cuda", requires_grad=needs_grad)
    for needs_grad in (False, True, True)
)
inputs = [x, scale, shift]

out = recurrent(*inputs)
recurrent.graph_for(*inputs)


# Re-import left over from concatenated scratch snippets; it rebinds the same
# cached module object, so it is a no-op.
import torch


@torch.jit.script
def recurrent_scaleshift(x, scale, shift, steps: int = 64):
    """Apply the affine update ``y = scale * y + shift`` to ``x`` ``steps`` times.

    Args:
        x: starting tensor.
        scale: multiplicative factor applied at every step.
        shift: additive offset applied at every step.
        steps: number of iterations; defaults to 64, the count this function
            originally hard-coded. ``steps <= 0`` returns ``x`` unchanged.

    Returns:
        The tensor after ``steps`` updates.
    """
    y = x
    for _ in range(steps):  # loop index is not needed
        y = scale * y + shift
    return y


# Fresh CUDA inputs for the scale/shift recurrence. ``x`` is drawn first,
# followed by ``scale`` and then ``shift``; the latter two require grad.
x = torch.randn(2, 2, device="cuda")
scale, shift = (
    torch.randn(2, 2, device="cuda", requires_grad=True) for _ in range(2)
)
inputs = [x, scale, shift]

# Run once, then fetch the graph the JIT produced for these arguments.
# The result is discarded, as in the snippet above.
out = recurrent_scaleshift(*inputs)
recurrent_scaleshift.graph_for(*inputs)


# Another leftover re-import; it rebinds the same cached module (no-op).
import torch


# Repro: backward through mean() of an empty leaf tensor, first on CPU and
# then through a CUDA copy of it. The inline comment below implies that only
# the CUDA path was expected to misbehave -- NOTE(review): confirm against
# the original report.
x = torch.tensor([], requires_grad=True)
x.mean().backward()  # no error triggered
x = x.cuda()
x.mean().backward()
54