# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    ProfilingMode,
    set_default_dtype,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
except RuntimeError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")


class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.reshape(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
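
# Shape note: with 1x28x28 MNIST inputs, conv1 (5x5) gives 24x24 maps,
# max_pool2d halves that to 12x12, conv2 (5x5) gives 8x8, and the second
# pool gives 4x4, so the flatten above is 20 * 4 * 4 = 320 features,
# matching fc1's input size.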


class TestModels(JitTestCase):
    @staticmethod
    def _test_dcgan_models(self, device, check_export_import=True):
        class DCGANGenerator(nn.Module):
            def __init__(self, nz, ngf, nc):
                super().__init__()
                self.main = nn.Sequential(
                    # input is Z, going into a convolution
                    nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                    nn.BatchNorm2d(ngf * 8),
                    nn.ReLU(True),
                    # state size. (ngf*8) x 4 x 4
                    nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 4),
                    nn.ReLU(True),
                    # state size. (ngf*4) x 8 x 8
                    nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf * 2),
                    nn.ReLU(True),
                    # state size. (ngf*2) x 16 x 16
                    nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ngf),
                    nn.ReLU(True),
                    # state size. (ngf) x 32 x 32
                    nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                    nn.Tanh()
                    # state size. (nc) x 64 x 64
                )

            def forward(self, input):
                return self.main(input)

        class DCGANDiscriminator(nn.Module):
            def __init__(self, nc, ndf):
                super().__init__()
                self.main = nn.Sequential(
                    # input is (nc) x 64 x 64
                    nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf) x 32 x 32
                    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 2),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*2) x 16 x 16
                    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 4),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*4) x 8 x 8
                    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                    nn.BatchNorm2d(ndf * 8),
                    nn.LeakyReLU(0.2, inplace=True),
                    # state size. (ndf*8) x 4 x 4
                    nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                    nn.Sigmoid(),
                )

            def forward(self, input):
                return self.main(input).view(-1, 1).squeeze(1)

        bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
        self.checkTrace(
            DCGANGenerator(nz, ngf, nc).to(device),
            (torch.rand(bs, nz, 1, 1, device=device),),
            export_import=check_export_import,
        )
        example_input = DCGANGenerator(nz, ngf, nc).to(device)(
            torch.rand(bs, nz, 1, 1, device=device)
        )
        self.checkTrace(
            DCGANDiscriminator(nc, ndf).to(device),
            (example_input,),
            export_import=check_export_import,
        )
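
    # checkTrace is a JitTestCase helper: roughly, it traces the module with
    # the given example inputs, checks the traced outputs (and gradients)
    # against eager execution, and with export_import=True also round-trips
    # the traced module through save/load before comparing.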

    def test_dcgan_models(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            self._test_dcgan_models(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_dcgan_models_cuda(self):
        # Note: Can sometimes fail with low precision if run with float dtype
        with set_default_dtype(torch.double):
            # XXX: export_import on CUDA modules doesn't work (#11480)
            self._test_dcgan_models(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_neural_style(self, device, check_export_import=True):
        class TransformerNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Initial convolution layers
                self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
                self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
                self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
                # Residual layers
                self.res1 = ResidualBlock(128)
                self.res2 = ResidualBlock(128)
                self.res3 = ResidualBlock(128)
                self.res4 = ResidualBlock(128)
                self.res5 = ResidualBlock(128)
                # Upsampling Layers
                self.deconv1 = UpsampleConvLayer(
                    128, 64, kernel_size=3, stride=1, upsample=2
                )
                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
                self.deconv2 = UpsampleConvLayer(
                    64, 32, kernel_size=3, stride=1, upsample=2
                )
                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
                self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                # Non-linearities
                self.relu = torch.nn.ReLU()

            def forward(self, X):
                y = self.relu(self.in1(self.conv1(X)))
                y = self.relu(self.in2(self.conv2(y)))
                y = self.relu(self.in3(self.conv3(y)))
                y = self.res1(y)
                y = self.res2(y)
                y = self.res3(y)
                y = self.res4(y)
                y = self.res5(y)
                y = self.relu(self.in4(self.deconv1(y)))
                y = self.relu(self.in5(self.deconv2(y)))
                y = self.deconv3(y)
                return y

        class ConvLayer(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, stride):
                super().__init__()
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                out = self.reflection_pad(x)
                out = self.conv2d(out)
                return out
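
        # ConvLayer pads by kernel_size // 2 with reflection, which preserves
        # spatial size for odd kernels at stride 1 while avoiding zero-padding
        # border artifacts.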

        class ResidualBlock(torch.nn.Module):
            """ResidualBlock
            introduced in: https://arxiv.org/abs/1512.03385
            recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
            """

            def __init__(self, channels):
                super().__init__()
                self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                residual = x
                out = self.relu(self.in1(self.conv1(x)))
                out = self.in2(self.conv2(out))
                out = out + residual
                return out

        class UpsampleConvLayer(torch.nn.Module):
            """UpsampleConvLayer
            Upsamples the input and then does a convolution. This method gives better results
            compared to ConvTranspose2d.
            ref: http://distill.pub/2016/deconv-checkerboard/
            """

            def __init__(
                self, in_channels, out_channels, kernel_size, stride, upsample=None
            ):
                super().__init__()
                self.upsample = upsample
                if upsample:
                    self.upsample_layer = torch.nn.Upsample(
                        mode="nearest", scale_factor=upsample
                    )
                reflection_padding = kernel_size // 2
                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
                self.conv2d = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride
                )

            def forward(self, x):
                x_in = x
                if self.upsample:
                    x_in = self.upsample_layer(x_in)
                out = self.reflection_pad(x_in)
                out = self.conv2d(out)
                return out

        self.checkTrace(
            TransformerNet(),
            (torch.rand(5, 3, 16, 16),),
            export_import=check_export_import,
        )

    @slowTest
    def test_neural_style(self):
        self._test_neural_style(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_neural_style_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_neural_style(self, device="cuda", check_export_import=False)

    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
    )
    @staticmethod
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                MnistNet().to(device).eval(),
                (torch.rand(5, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_mnist(self):
        self._test_mnist(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_mnist(self, device="cuda", check_export_import=False)

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_mnist_training_leaks_no_memory_cuda(self):
        net = MnistNet().cuda()
        # MnistNet uses dropout, don't check its trace
        traced_net = torch.jit.trace(
            net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
        )

        def train(iters):
            for _ in range(iters):
                # Get some fake data
                inp = torch.randn(5, 1, 28, 28, device="cuda")
                out = traced_net(inp)

                # Here's some fake loss
                out.sum().backward()

                # Zero out grads
                traced_net.zero_grad()

        # Set it up so the params have .grad fields so they are not reported as leaks
        train(1)
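
        # assertLeaksNoCudaTensors is a JitTestCase helper; roughly, it records
        # the live CUDA tensors around the block below and fails if extra
        # tensors are still alive afterwards.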
        with self.assertLeaksNoCudaTensors():
            train(5)

    @staticmethod
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            self.checkTrace(
                Policy().to(device),
                (torch.rand(1, 4, device=device),),
                export_import=test_export_import,
            )

    def test_reinforcement_learning(self):
        self._test_reinforcement_learning(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_reinforcement_learning_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_reinforcement_learning(self, device="cuda", test_export_import=False)

    @staticmethod
    def _test_snli(self, device, check_export_import=True):
        class Bottle(nn.Module):
            def forward(self, input):
                if len(input.size()) <= 2:
                    return super().forward(input)
                size = input.size()[:2]
                out = super().forward(input.view(size[0] * size[1], -1))
                return out.view(size[0], size[1], -1)

        class Linear(Bottle, nn.Linear):
            pass
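
        # Cooperative-mixin trick: Linear's MRO is (Bottle, nn.Linear), so the
        # super().forward calls in Bottle dispatch to nn.Linear.forward; 3-D
        # (seq, batch, features) inputs are flattened to 2-D for the matmul
        # and then reshaped back.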

        class Encoder(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                input_size = config.d_proj if config.projection else config.d_embed
                dropout = 0 if config.n_layers == 1 else config.dp_ratio
                self.rnn = nn.LSTM(
                    input_size=input_size,
                    hidden_size=config.d_hidden,
                    num_layers=config.n_layers,
                    dropout=dropout,
                    bidirectional=config.birnn,
                )

            def forward(self, inputs):
                batch_size = inputs.size()[1]
                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
                h0 = c0 = inputs.new_zeros(state_shape)
                outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
                return (
                    ht[-1]
                    if not self.config.birnn
                    else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
                )
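
        # For a bidirectional LSTM, ht[-2:] holds the last layer's forward and
        # backward final hidden states; transpose + view flattens them into a
        # (batch, 2 * d_hidden) sentence encoding.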

        class SNLIClassifier(nn.Module):
            def __init__(self, config):
                super().__init__()
                self.config = config
                self.embed = nn.Embedding(config.n_embed, config.d_embed)
                self.projection = Linear(config.d_embed, config.d_proj)
                self.encoder = Encoder(config)
                self.dropout = nn.Dropout(p=config.dp_ratio)
                self.relu = nn.ReLU()
                seq_in_size = 2 * config.d_hidden
                if self.config.birnn:
                    seq_in_size *= 2
                lin_config = [seq_in_size] * 2
                self.out = nn.Sequential(
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(*lin_config),
                    self.relu,
                    self.dropout,
                    Linear(seq_in_size, config.d_out),
                )

            def forward(self, premise, hypothesis):
                prem_embed = self.embed(premise)
                hypo_embed = self.embed(hypothesis)
                if self.config.fix_emb:
                    prem_embed = prem_embed.detach()
                    hypo_embed = hypo_embed.detach()
                if self.config.projection:
                    prem_embed = self.relu(self.projection(prem_embed))
                    hypo_embed = self.relu(self.projection(hypo_embed))
                premise = self.encoder(prem_embed)
                hypothesis = self.encoder(hypo_embed)
                scores = self.out(torch.cat([premise, hypothesis], 1))
                return scores

        class Config:
            n_embed = 100
            d_embed = 100
            d_proj = 300
            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
            d_hidden = 30
            birnn = True
            d_out = 300
            fix_emb = True
            projection = True
            n_layers = 2
            n_cells = 4  # 2 * n_layers because birnn = True

        premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
        hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)

        self.checkTrace(
            SNLIClassifier(Config()).to(device),
            (premise, hypothesis),
            inputs_require_grads=False,
            export_import=check_export_import,
        )

    @slowTest
    def test_snli(self):
        self._test_snli(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_snli_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_snli(self, device="cuda", check_export_import=False)

    @staticmethod
    def _test_super_resolution(self, device, check_export_import=True):
        class Net(nn.Module):
            def __init__(self, upscale_factor):
                super().__init__()

                self.relu = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x
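
        # nn.PixelShuffle(r) rearranges (N, C * r**2, H, W) into
        # (N, C, H * r, W * r), so conv4's upscale_factor**2 output channels
        # become a single upscaled plane.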

        net = Net(upscale_factor=4).to(device)
        self.checkTrace(
            net,
            (torch.rand(5, 1, 32, 32, device=device),),
            export_import=check_export_import,
        )

    @slowTest
    def test_super_resolution(self):
        self._test_super_resolution(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_super_resolution_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_super_resolution(self, device="cuda", check_export_import=False)

    @suppress_warnings
    def test_time_sequence_prediction(self):
        class Sequence(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.lstm1 = nn.LSTMCell(1, 51)
                self.lstm2 = nn.LSTMCell(51, 51)
                self.linear = nn.Linear(51, 1)

            @torch.jit.script_method
            def forward(self, input):
                # TODO: add future as input with default val
                # see https://github.com/pytorch/pytorch/issues/8724
                outputs = torch.empty((3, 0))
                h_t = torch.zeros((3, 51))
                c_t = torch.zeros((3, 51))
                h_t2 = torch.zeros((3, 51))
                c_t2 = torch.zeros((3, 51))

                output = torch.zeros([3, 51])
                future = 2

                # TODO: chunk call should appear as the for loop iterable
                # We hard-code it to 4 for now.
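                # (The test input below is rand(3, 4), so input.size(1) == 4
                # and chunk yields four (3, 1) column tensors.)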
                a, b, c, d = input.chunk(input.size(1), dim=1)
                for input_t in (a, b, c, d):
                    h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                for _ in range(future):  # if we should predict the future
                    h_t, c_t = self.lstm1(output, (h_t, c_t))
                    h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                    output = self.linear(h_t2)
                    outputs = torch.cat((outputs, output), 1)
                return outputs

        class Traced(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = Sequence()

            def forward(self, input):
                return self.seq.forward(input)

        # disabled due to a jitter issue that will be fixed by using load/store in the compiler
        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)

    @staticmethod
    def _test_vae(self, device, check_export_import=True):
        class VAE(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu
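
            # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I)
            # keeps the sample differentiable w.r.t. mu and logvar; in eval
            # mode the mean is returned, which keeps the traced test
            # deterministic.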

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(
                VAE().to(device).eval(),
                (torch.rand(128, 1, 28, 28, device=device),),
                export_import=check_export_import,
            )

    def test_vae(self):
        self._test_vae(self, device="cpu")

    @unittest.skipIf(not RUN_CUDA, "no CUDA")
    def test_vae_cuda(self):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        self._test_vae(self, device="cuda", check_export_import=False)

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_trace_resnet18(self):
        x = torch.ones(1, 3, 224, 224)
        m_orig = torch.jit.trace(
            torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
        )
        m_import = self.getExportImportCopy(m_orig)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = m_orig(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()

        output_import = m_import(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)
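
    # The test below re-implements torchvision's ResNet-18 on top of
    # torch.jit.ScriptModule / @torch.jit.script_method (the legacy scripting
    # API), so control flow such as the optional downsample branch is
    # compiled rather than traced.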

    @slowTest
    @skipIfNoTorchVision
    def test_script_module_script_resnet(self):
        def conv1x1(in_planes, out_planes, stride=1):
            """1x1 convolution"""
            return nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, bias=False
            )

        def conv3x3(in_planes, out_planes, stride=1):
            """3x3 convolution with padding"""
            return nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )

        class BasicBlock(torch.jit.ScriptModule):
            expansion = 1
            __constants__ = ["downsample"]

            def __init__(self, inplanes, planes, stride=1, downsample=None):
                super().__init__()
                self.conv1 = conv3x3(inplanes, planes, stride)
                self.bn1 = nn.BatchNorm2d(planes)
                self.relu = nn.ReLU(inplace=True)
                self.conv2 = conv3x3(planes, planes)
                self.bn2 = nn.BatchNorm2d(planes)
                self.downsample = downsample
                self.stride = stride

            @torch.jit.script_method
            def forward(self, x):
                residual = x

                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)

                out = self.conv2(out)
                out = self.bn2(out)

                if self.downsample is not None:
                    residual = self.downsample(x)

                out += residual
                out = self.relu(out)

                return out

        class ResNet(torch.jit.ScriptModule):
            __constants__ = ["layer1", "layer2", "layer3", "layer4"]

            def __init__(self, block, layers, num_classes=1000):
                super().__init__()
                self.inplanes = 64
                self.conv1 = nn.Conv2d(
                    3, 64, kernel_size=7, stride=2, padding=3, bias=False
                )
                self.bn1 = nn.BatchNorm2d(64)
                self.relu = nn.ReLU(inplace=True)
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                self.layer1 = self._make_layer(block, 64, layers[0])
                self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
                self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
                self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
                self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
                self.fc = nn.Linear(512 * block.expansion, num_classes)

                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(
                            m.weight, mode="fan_out", nonlinearity="relu"
                        )
                    elif isinstance(m, nn.BatchNorm2d):
                        nn.init.constant_(m.weight, 1)
                        nn.init.constant_(m.bias, 0)

            def _make_layer(self, block, planes, blocks, stride=1):
                downsample = None
                if stride != 1 or self.inplanes != planes * block.expansion:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, planes * block.expansion, stride),
                        nn.BatchNorm2d(planes * block.expansion),
                    )

                layers = []
                layers.append(block(self.inplanes, planes, stride, downsample))
                self.inplanes = planes * block.expansion
                for _ in range(1, blocks):
                    layers.append(block(self.inplanes, planes))

                return nn.Sequential(*layers)
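
            # In _make_layer above, when the stride or channel width changes,
            # the 1x1 conv + BatchNorm "downsample" branch projects the
            # identity so its shape matches for the residual addition in
            # BasicBlock.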

            @torch.jit.script_method
            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu(x)
                x = self.maxpool(x)

                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                x = self.layer4(x)

                x = self.avgpool(x)
                x = x.view(x.size(0), -1)
                x = self.fc(x)

                return x

        resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

        resnet18_imported = self.getExportImportCopy(resnet18)

        input = torch.randn(1, 3, 224, 224, requires_grad=True)
        output_orig = resnet18(input)
        output_orig.sum().backward()
        grad_orig = input.grad.clone()
        input.grad.zero_()
        output_import = resnet18_imported(input)
        output_import.sum().backward()
        grad_import = input.grad.clone()

        self.assertEqual(output_orig, output_import)
        self.assertEqual(grad_orig, grad_import)

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        with torch.random.fork_rng(devices=[]):
            g, outputs, inputs = torch.jit._get_trace_graph(
                model, x, return_inputs=True
            )
        self.run_pass("cse", g)
        m = self.createFunctionFromGraph(g)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))