# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    ProfilingMode,
    set_default_dtype,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.reshape(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

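# A minimal usage sketch (our addition, not part of the original tests),
# mirroring how _test_mnist below exercises MnistNet: trace on an example
# batch, with eval() so Dropout2d/F.dropout don't make the forward pass
# nondeterministic. The helper name is ours.
def _example_trace_mnist():
    net = MnistNet().eval()
    example_batch = torch.rand(5, 1, 28, 28)
    return torch.jit.trace(net, (example_batch,))

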
class TestModels(JitTestCase):
    @staticmethod
    def _test_dcgan_models(self, device, check_export_import=True):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super().__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh()
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super().__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        self.checkTrace(
            DCGANGenerator(nz, ngf, nc).to(device),
            (torch.rand(bs, nz, 1, 1, device=device),),
            export_import=check_export_import,
        )
        example_input = DCGANGenerator(nz, ngf, nc).to(device)(
            torch.rand(bs, nz, 1, 1, device=device)
        )
        self.checkTrace(
            DCGANDiscriminator(nc, ndf).to(device),
            (example_input,),
            export_import=check_export_import,
        )

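    # Hedged sketch (our addition, assuming checkTrace's intent from its use
    # here): the core of a trace check is tracing the module on an example
    # input and comparing traced output against the eager module; checkTrace
    # itself also handles gradients and the optional export/import round trip.
    @staticmethod
    def _sketch_trace_check(module, example_input):
        traced = torch.jit.trace(module, (example_input,))
        torch.testing.assert_close(traced(example_input), module(example_input))
        return traced
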
    def test_dcgan_models(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            self._test_dcgan_models(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_dcgan_models_cuda(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            # XXX: export_import on CUDA modules doesn't work (#11480)
            self._test_dcgan_models(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_neural_style(self, device, check_export_import=True):
        class TransformerNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(
                    128, 64, kernel_size=3, stride=1, upsample=2
                )
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(
                    64, 32, kernel_size=3, stride=1, upsample=2
                )
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super().__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(
                self, in_channels, out_channels, kernel_size, stride, upsample=None
            ):
                super().__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(
                        mode="nearest", scale_factor=upsample
                    )
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        self.checkTrace(
            TransformerNet(),
            (torch.rand(5, 3, 16, 16),),
            export_import=check_export_import,
        )

    @slowTest
    def test_neural_style(self):
        self._test_neural_style(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_neural_style_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_neural_style(self, device="cuda", check_export_import=False)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
    )
    @staticmethod
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                MnistNet().to(device).eval(),
                (torch.rand(5, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

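    # Hedged illustration (our addition): why eval() matters above. In train
    # mode dropout is stochastic, so repeated forward passes generally differ;
    # in eval mode the forward pass is deterministic and traces cleanly.
    @staticmethod
    def _sketch_dropout_determinism():
        net = MnistNet()
        inp = torch.rand(5, 1, 28, 28)
        net.train()
        noisy_a, noisy_b = net(inp), net(inp)  # typically not equal
        net.eval()
        torch.testing.assert_close(net(inp), net(inp))  # deterministic
        return noisy_a, noisy_b
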
    def test_mnist(self):
        self._test_mnist(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_mnist(self, device="cuda", check_export_import=False)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_training_leaks_no_memory_cuda(self):
        net = MnistNet().cuda()
        # MnistNet uses dropout, don't check its trace
        traced_net = torch.jit.trace(
            net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
        )

        def train(iters):
            for _ in range(iters):
                # Get some fake data
                inp = torch.randn(5, 1, 28, 28, device="cuda")
                out = traced_net(inp)

                # Here's some fake loss
                out.sum().backward()

                # Zero out grads
                traced_net.zero_grad()

        # Set it up so the params have .grad fields so they are not reported as leaks
        train(1)

        with self.assertLeaksNoCudaTensors():
            train(5)

    @staticmethod
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                Policy().to(device),
                (torch.rand(1, 4, device=device),),
                export_import=test_export_import,
            )

    def test_reinforcement_learning(self):
        self._test_reinforcement_learning(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_reinforcement_learning_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_reinforcement_learning(self, device="cuda", test_export_import=False)

    @staticmethod
    def _test_snli(self, device, check_export_import=True):
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super().forward(input)
                size = input.size()[:2]
                out = super().forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass

        class Encoder(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(
                    input_size=input_size,
                    hidden_size=config.d_hidden,
                    num_layers=config.n_layers,
                    dropout=dropout,
                    bidirectional=config.birnn,
                )

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
                return (
                    ht[-1]
                    if not self.config.birnn
                    else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
                )

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out),
                )

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

        premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

        self.checkTrace(
            SNLIClassifier(Config()).to(device),
            (premise, hypothesis),
            inputs_require_grads=False,
            export_import=check_export_import,
        )

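    # A minimal sketch (our addition) of the Bottle pattern used above: fold
    # the leading (sequence, batch) dims into one so a 2-D module such as
    # nn.Linear can be applied to a 3-D sequence batch, then unfold the result.
    @staticmethod
    def _sketch_bottle_apply(linear, x):
        seq_len, batch = x.size(0), x.size(1)
        return linear(x.reshape(seq_len * batch, -1)).view(seq_len, batch, -1)
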
    @slowTest
    def test_snli(self):
        self._test_snli(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_snli_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_snli(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_super_resolution(self, device, check_export_import=True):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super().__init__()

                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

        net = Net(upscale_factor=4).to(device)
        self.checkTrace(
            net,
            (torch.rand(5, 1, 32, 32, device=device),),
            export_import=check_export_import,
        )

    @slowTest
    def test_super_resolution(self):
        self._test_super_resolution(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_super_resolution_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_super_resolution(self, device="cuda", check_export_import=False)

    @suppress_warnings
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))

                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        # disabled due to a jitter issue that will be fixed by using load/store in the compiler
        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)

    @staticmethod
    def _test_vae(self, device, check_export_import=True):
        class VAE(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(
                VAE().to(device).eval(),
                (torch.rand(128, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

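    # Hedged sketch (our addition, following the standard VAE objective from
    # Kingma & Welling, arXiv:1312.6114): the loss this model would be trained
    # with is a reconstruction term plus the analytic KL divergence of the
    # diagonal Gaussian posterior. Not used by the trace test above.
    @staticmethod
    def _sketch_vae_loss(recon_x, x, mu, logvar):
        bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kld
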
    def test_vae(self):
        self._test_vae(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_vae_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_vae(self, device="cuda", check_export_import=False)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_trace_resnet18(self):
        x = torch.ones(1, 3, 224, 224)
        m_orig = torch.jit.trace(torchvision.models.resnet18(), x)
        m_import = self.getExportImportCopy(m_orig)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = m_orig(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()

        output_import = m_import(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_script_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
            __constants__ = ["downsample"]

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super().__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ["layer1", "layer2", "layer3", "layer4"]

            def __init__(self, block, layers, num_classes=1000):
                super().__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(
                            m.weight, mode="fan_out", nonlinearity="relu"
                        )
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )

                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

        resnet18_imported = self.getExportImportCopy(resnet18)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = resnet18(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()
        output_import = resnet18_imported(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        with torch.random.fork_rng(devices=[]):
            g, outputs, inputs = torch.jit._get_trace_graph(
                model, x, return_inputs=True
            )
        self.run_pass("cse", g)
        m = self.createFunctionFromGraph(g)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))
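

# Hedged illustration (our addition): test_alexnet wraps both runs in
# torch.random.fork_rng so stochastic ops (AlexNet uses dropout) see the same
# RNG stream. fork_rng restores the outer RNG state on exit, so two forks
# taken from the same outer state replay identically.
def _sketch_fork_rng_repeatability():
    with torch.random.fork_rng(devices=[]):
        a = torch.rand(4)
    with torch.random.fork_rng(devices=[]):
        b = torch.rand(4)
    torch.testing.assert_close(a, b)  # same draws in both forks
    return a, b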
757*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputs, m(*inputs))
758