1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, Optional 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport numpy as np 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerfrom torch import nn 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.static_module import StaticModule 11*da0073e9SAndroid Build Coastguard Workerfrom typing import List 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerdef linear_shim( 15*da0073e9SAndroid Build Coastguard Worker input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None 16*da0073e9SAndroid Build Coastguard Worker) -> torch.Tensor: 17*da0073e9SAndroid Build Coastguard Worker output = input.matmul(weight.t()) 18*da0073e9SAndroid Build Coastguard Worker if bias is not None: 19*da0073e9SAndroid Build Coastguard Worker output += bias 20*da0073e9SAndroid Build Coastguard Worker ret = output 21*da0073e9SAndroid Build Coastguard Worker return ret 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workertorch.nn.functional.linear = linear_shim 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerclass MultiHeadAttentionLayer(nn.Module): 28*da0073e9SAndroid Build Coastguard Worker def __init__(self, hid_dim, n_heads, dropout, device): 29*da0073e9SAndroid Build Coastguard Worker super().__init__() 30*da0073e9SAndroid Build Coastguard Worker assert hid_dim % n_heads == 0 31*da0073e9SAndroid Build Coastguard Worker self.hid_dim = hid_dim 32*da0073e9SAndroid Build Coastguard Worker self.n_heads = n_heads 33*da0073e9SAndroid Build Coastguard Worker self.head_dim = hid_dim // n_heads 34*da0073e9SAndroid Build Coastguard Worker self.fc_q = nn.Linear(hid_dim, hid_dim) 35*da0073e9SAndroid Build Coastguard Worker self.fc_k = nn.Linear(hid_dim, hid_dim) 36*da0073e9SAndroid Build Coastguard Worker self.fc_v = nn.Linear(hid_dim, hid_dim) 37*da0073e9SAndroid Build Coastguard Worker self.fc_o = nn.Linear(hid_dim, hid_dim) 38*da0073e9SAndroid Build Coastguard Worker # self.dropout = nn.Dropout(dropout) 39*da0073e9SAndroid Build Coastguard Worker self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def forward(self, query, key, value, mask): 42*da0073e9SAndroid Build Coastguard Worker batch_size = query.shape[0] 43*da0073e9SAndroid Build Coastguard Worker Q = self.fc_q(query) 44*da0073e9SAndroid Build Coastguard Worker K = self.fc_k(key) 45*da0073e9SAndroid Build Coastguard Worker V = self.fc_v(value) 46*da0073e9SAndroid Build Coastguard Worker Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 47*da0073e9SAndroid Build Coastguard Worker K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 48*da0073e9SAndroid Build Coastguard Worker V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 49*da0073e9SAndroid Build Coastguard Worker energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 50*da0073e9SAndroid Build Coastguard Worker # energy = energy.masked_fill(mask == 0, -1e10) 51*da0073e9SAndroid Build Coastguard Worker attention = torch.softmax(energy, dim=-1) 52*da0073e9SAndroid Build Coastguard Worker # x = torch.matmul(self.dropout(attention), V) 53*da0073e9SAndroid Build Coastguard Worker x = torch.matmul(attention, V) 54*da0073e9SAndroid Build Coastguard Worker x = x.permute(0, 2, 1, 3).contiguous() 55*da0073e9SAndroid Build Coastguard Worker x = x.view(batch_size, -1, self.hid_dim) 56*da0073e9SAndroid Build Coastguard Worker x = self.fc_o(x) 57*da0073e9SAndroid Build Coastguard Worker return x, attention 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py 61*da0073e9SAndroid Build Coastguard Workerdef create_mlp(ln, sigmoid_layer): 62*da0073e9SAndroid Build Coastguard Worker layers = nn.ModuleList() 63*da0073e9SAndroid Build Coastguard Worker for i in range(0, len(ln) - 1): 64*da0073e9SAndroid Build Coastguard Worker n = ln[i] 65*da0073e9SAndroid Build Coastguard Worker m = ln[i + 1] 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker LL = nn.Linear(int(n), int(m), bias=True) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker mean = 0.0 # std_dev = np.sqrt(variance) 70*da0073e9SAndroid Build Coastguard Worker std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n) 71*da0073e9SAndroid Build Coastguard Worker W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32) 72*da0073e9SAndroid Build Coastguard Worker std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1)) 73*da0073e9SAndroid Build Coastguard Worker bt = np.random.normal(mean, std_dev, size=m).astype(np.float32) 74*da0073e9SAndroid Build Coastguard Worker LL.weight.data = torch.tensor(W, requires_grad=True) 75*da0073e9SAndroid Build Coastguard Worker LL.bias.data = torch.tensor(bt, requires_grad=True) 76*da0073e9SAndroid Build Coastguard Worker layers.append(LL) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker if i == sigmoid_layer: 79*da0073e9SAndroid Build Coastguard Worker layers.append(nn.Sigmoid()) 80*da0073e9SAndroid Build Coastguard Worker else: 81*da0073e9SAndroid Build Coastguard Worker layers.append(nn.ReLU()) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 84*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(torch.nn.Sequential(*layers)) 85*da0073e9SAndroid Build Coastguard Worker s.eval() 86*da0073e9SAndroid Build Coastguard Worker return s 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Workerdef trivial_graph(a, b, c): 90*da0073e9SAndroid Build Coastguard Worker s = torch.tensor([[3, 3], [3, 3]]) 91*da0073e9SAndroid Build Coastguard Worker return a + b * c + s 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Workerdef elementwise_square_addition(input1, input2): 94*da0073e9SAndroid Build Coastguard Worker return input1 * input1 + input2 * input2 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph1(input1, input2): 97*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(elementwise_square_addition, input1, input2) 98*da0073e9SAndroid Build Coastguard Worker return torch.jit.wait(fut) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph2(input1, input2): 101*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(loop_graph, input1, input2, 5) 102*da0073e9SAndroid Build Coastguard Worker return torch.jit.wait(fut) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker""" 105*da0073e9SAndroid Build Coastguard Worker graph with multiple fork/wait operations 106*da0073e9SAndroid Build Coastguard Worker :param input: torch.tensor input to forked subgraph 107*da0073e9SAndroid Build Coastguard Worker :param iters: number of future/wait pairs to be created 108*da0073e9SAndroid Build Coastguard Worker""" 109*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph3(input, iters: int): 110*da0073e9SAndroid Build Coastguard Worker futures : List[torch.jit.Future[torch.Tensor]] = [] 111*da0073e9SAndroid Build Coastguard Worker for _ in range(iters): 112*da0073e9SAndroid Build Coastguard Worker futures.append(torch.jit.fork(torch.neg, input)) 113*da0073e9SAndroid Build Coastguard Worker results = [] 114*da0073e9SAndroid Build Coastguard Worker for future in futures: 115*da0073e9SAndroid Build Coastguard Worker results.append(torch.jit.wait(future)) 116*da0073e9SAndroid Build Coastguard Worker return torch.sum(torch.stack(results)) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker""" 119*da0073e9SAndroid Build Coastguard Worker graph with multi-level fork/wait operations 120*da0073e9SAndroid Build Coastguard Worker :param input: torch.tensor input to forked subgraph 121*da0073e9SAndroid Build Coastguard Worker :param num_forks: number of top level forks 122*da0073e9SAndroid Build Coastguard Worker :param num_child_forks: number of child forks per parent fork 123*da0073e9SAndroid Build Coastguard Worker""" 124*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph4(input, num_forks: int, num_child_forks: int): 125*da0073e9SAndroid Build Coastguard Worker futures : List[torch.jit.Future[torch.Tensor]] = [] 126*da0073e9SAndroid Build Coastguard Worker for _ in range(num_forks): 127*da0073e9SAndroid Build Coastguard Worker futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks)) 128*da0073e9SAndroid Build Coastguard Worker results = [] 129*da0073e9SAndroid Build Coastguard Worker for future in futures: 130*da0073e9SAndroid Build Coastguard Worker results.append(torch.jit.wait(future)) 131*da0073e9SAndroid Build Coastguard Worker return torch.sum(torch.stack(results)) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Workerdef add_tensor(input1, input2): 134*da0073e9SAndroid Build Coastguard Worker return input1 + input2 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph_exception(input1, input2): 137*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(add_tensor, input1, input2) 138*da0073e9SAndroid Build Coastguard Worker return torch.jit.wait(fut) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Workerdef loop_graph(a, b, iters: int): 141*da0073e9SAndroid Build Coastguard Worker c = a + b * 2 142*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 143*da0073e9SAndroid Build Coastguard Worker c = c + b 144*da0073e9SAndroid Build Coastguard Worker c *= 2 145*da0073e9SAndroid Build Coastguard Worker c -= a 146*da0073e9SAndroid Build Coastguard Worker return c 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Workerdef output_graph(a, b, c, iters: int): 150*da0073e9SAndroid Build Coastguard Worker s = torch.tensor([[3, 3], [3, 3]]) 151*da0073e9SAndroid Build Coastguard Worker k = a + b * c + s 152*da0073e9SAndroid Build Coastguard Worker d: Dict[int, torch.Tensor] = {} 153*da0073e9SAndroid Build Coastguard Worker for i in range(iters): 154*da0073e9SAndroid Build Coastguard Worker d[i] = k + i 155*da0073e9SAndroid Build Coastguard Worker return d 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Workerclass SubModule(nn.Module): 159*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 160*da0073e9SAndroid Build Coastguard Worker super().__init__() 161*da0073e9SAndroid Build Coastguard Worker self.a = 11 162*da0073e9SAndroid Build Coastguard Worker self.b = 2 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 165*da0073e9SAndroid Build Coastguard Worker return self.a + self.b + x 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Workerclass SubModule2(nn.Module): 169*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 170*da0073e9SAndroid Build Coastguard Worker super().__init__() 171*da0073e9SAndroid Build Coastguard Worker self.a = 12 172*da0073e9SAndroid Build Coastguard Worker self.b = 2 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 175*da0073e9SAndroid Build Coastguard Worker self.b = 30 176*da0073e9SAndroid Build Coastguard Worker return self.a + self.b + x 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Workerclass TestModule(nn.Module): 180*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 181*da0073e9SAndroid Build Coastguard Worker super().__init__() 182*da0073e9SAndroid Build Coastguard Worker self.sub1 = SubModule() 183*da0073e9SAndroid Build Coastguard Worker self.sub2 = SubModule2() 184*da0073e9SAndroid Build Coastguard Worker self.a = 3 185*da0073e9SAndroid Build Coastguard Worker self.b = 4 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 188*da0073e9SAndroid Build Coastguard Worker self.b = 20 189*da0073e9SAndroid Build Coastguard Worker return self.sub1(x) + self.a + self.b + self.sub2(x) 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Workerclass TestStaticModule(TestCase): 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker """ 195*da0073e9SAndroid Build Coastguard Worker Test Case: To test simple fork/wait operation in a graph 196*da0073e9SAndroid Build Coastguard Worker fork is called on simple addition operation on input tensors 197*da0073e9SAndroid Build Coastguard Worker """ 198*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_1(self): 199*da0073e9SAndroid Build Coastguard Worker inp1 = torch.ones(5, 5) 200*da0073e9SAndroid Build Coastguard Worker inp2 = torch.randn(5, 5) 201*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph1) 202*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(inp1, inp2) 203*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 204*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module(inp1, inp2) 205*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test, output_ref) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker """ 208*da0073e9SAndroid Build Coastguard Worker Test Case: To test simple fork/wait operation with 209*da0073e9SAndroid Build Coastguard Worker StaticRuntime runAsync API returning future 210*da0073e9SAndroid Build Coastguard Worker """ 211*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_1_async(self): 212*da0073e9SAndroid Build Coastguard Worker inp1 = torch.ones(5, 5) 213*da0073e9SAndroid Build Coastguard Worker inp2 = torch.randn(5, 5) 214*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph1) 215*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(inp1, inp2) 216*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 217*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module.runAsync((inp1, inp2), {}) 218*da0073e9SAndroid Build Coastguard Worker output_test.wait() 219*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test.value(), output_ref) 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker """ 222*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation in a graph on 223*da0073e9SAndroid Build Coastguard Worker a loop subgraph performing mix of operations 224*da0073e9SAndroid Build Coastguard Worker """ 225*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_2(self): 226*da0073e9SAndroid Build Coastguard Worker inp1 = torch.randn(5, 5) 227*da0073e9SAndroid Build Coastguard Worker inp2 = torch.randn(5, 5) 228*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph2) 229*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(inp1, inp2) 230*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 231*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module(inp1, inp2) 232*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test, output_ref) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker """ 235*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation on a loop 236*da0073e9SAndroid Build Coastguard Worker subgraph with StaticRuntime runAsync API returning future 237*da0073e9SAndroid Build Coastguard Worker """ 238*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_2_async(self): 239*da0073e9SAndroid Build Coastguard Worker inp1 = torch.randn(5, 5) 240*da0073e9SAndroid Build Coastguard Worker inp2 = torch.randn(5, 5) 241*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph2) 242*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(inp1, inp2) 243*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 244*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module.runAsync((inp1, inp2), {}) 245*da0073e9SAndroid Build Coastguard Worker output_test.wait() 246*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test.value(), output_ref) 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker """ 249*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation in a graph on 250*da0073e9SAndroid Build Coastguard Worker having multiple fork/wait operations 251*da0073e9SAndroid Build Coastguard Worker """ 252*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_3(self): 253*da0073e9SAndroid Build Coastguard Worker input = torch.ones(3, 3) 254*da0073e9SAndroid Build Coastguard Worker num_forks = 10 255*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph3) 256*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(input, num_forks) 257*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 258*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module(input, num_forks) 259*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test, output_ref) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker """ 262*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation in a graph with 263*da0073e9SAndroid Build Coastguard Worker multiple fork/wait operations on runAsync API returning future 264*da0073e9SAndroid Build Coastguard Worker """ 265*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_3_async(self): 266*da0073e9SAndroid Build Coastguard Worker input = torch.ones(3, 3) 267*da0073e9SAndroid Build Coastguard Worker num_forks = 10 268*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph3) 269*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(input, num_forks) 270*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 271*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module.runAsync((input, num_forks), {}) 272*da0073e9SAndroid Build Coastguard Worker output_test.wait() 273*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test.value(), output_ref) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker """ 276*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation in a graph on 277*da0073e9SAndroid Build Coastguard Worker multiple nested fork/wait operations 278*da0073e9SAndroid Build Coastguard Worker """ 279*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782") 280*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_4(self): 281*da0073e9SAndroid Build Coastguard Worker input = torch.ones(3, 3) 282*da0073e9SAndroid Build Coastguard Worker num_forks = 10 283*da0073e9SAndroid Build Coastguard Worker num_child_forks = 10 284*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph4) 285*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 286*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(input, num_forks, num_child_forks) 287*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module(input, num_forks, num_child_forks) 288*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test, output_ref) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker """ 291*da0073e9SAndroid Build Coastguard Worker Test Case: To test fork/wait operation in a graph with multiple 292*da0073e9SAndroid Build Coastguard Worker nested fork/wait operations on runAsync API returning future 293*da0073e9SAndroid Build Coastguard Worker """ 294*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782") 295*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_4_async(self): 296*da0073e9SAndroid Build Coastguard Worker input = torch.ones(3, 3) 297*da0073e9SAndroid Build Coastguard Worker num_forks = 10 298*da0073e9SAndroid Build Coastguard Worker num_child_forks = 10 299*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph4) 300*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 301*da0073e9SAndroid Build Coastguard Worker output_ref = torch_graph(input, num_forks, num_child_forks) 302*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module.runAsync( 303*da0073e9SAndroid Build Coastguard Worker (input, num_forks, num_child_forks), {}) 304*da0073e9SAndroid Build Coastguard Worker output_test.wait() 305*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_test.value(), output_ref) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker """ 308*da0073e9SAndroid Build Coastguard Worker Test Case: To test exception handling in fork/wait 309*da0073e9SAndroid Build Coastguard Worker operation. Add.Tensor op is called for tensors with 310*da0073e9SAndroid Build Coastguard Worker non-matching dims on the forked subgraph and the 311*da0073e9SAndroid Build Coastguard Worker exception raised by subgraph is set on future returned 312*da0073e9SAndroid Build Coastguard Worker by prim::fork to parent graph. Returned exception is 313*da0073e9SAndroid Build Coastguard Worker checked for substring expected_error_msg as declared below 314*da0073e9SAndroid Build Coastguard Worker """ 315*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_exception(self): 316*da0073e9SAndroid Build Coastguard Worker # incompatible tensors for add due to shape mismatch 317*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 7) 318*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 5) 319*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph_exception) 320*da0073e9SAndroid Build Coastguard Worker try: 321*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 322*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module(input1, input2) 323*da0073e9SAndroid Build Coastguard Worker except Exception as error: 324*da0073e9SAndroid Build Coastguard Worker expected_error_msg = ( 325*da0073e9SAndroid Build Coastguard Worker "The size of tensor a (7) must match the size " 326*da0073e9SAndroid Build Coastguard Worker "of tensor b (5) at non-singleton dimension 1" 327*da0073e9SAndroid Build Coastguard Worker ) 328*da0073e9SAndroid Build Coastguard Worker # test fails if error does not contain expected substr 329*da0073e9SAndroid Build Coastguard Worker if str(error).find(expected_error_msg) == -1: 330*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 331*da0073e9SAndroid Build Coastguard Worker "Tried execution of add.Tensors with incompatible shape. " 332*da0073e9SAndroid Build Coastguard Worker "Exception raised by forked runtime execution does " 333*da0073e9SAndroid Build Coastguard Worker f'not contain expected substring: "{expected_error_msg}"' 334*da0073e9SAndroid Build Coastguard Worker ) from error 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker """ 337*da0073e9SAndroid Build Coastguard Worker Test Case: To test exception handling in fork/wait 338*da0073e9SAndroid Build Coastguard Worker operation with runAsync API. Add.Tensor op is called for 339*da0073e9SAndroid Build Coastguard Worker tensors with non-matching dims on the forked subgraph 340*da0073e9SAndroid Build Coastguard Worker and the exception raised by subgraph is set on future returned 341*da0073e9SAndroid Build Coastguard Worker by prim::fork to parent graph. Returned exception is 342*da0073e9SAndroid Build Coastguard Worker checked for substring expected_error_msg as declared below 343*da0073e9SAndroid Build Coastguard Worker """ 344*da0073e9SAndroid Build Coastguard Worker def test_fork_wait_exception_async(self): 345*da0073e9SAndroid Build Coastguard Worker # incompatible tensors for add due to shape mismatch 346*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(4, 7) 347*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(4, 5) 348*da0073e9SAndroid Build Coastguard Worker torch_graph = torch.jit.script(fork_wait_graph_exception) 349*da0073e9SAndroid Build Coastguard Worker try: 350*da0073e9SAndroid Build Coastguard Worker static_runtime_module = StaticModule(torch_graph) 351*da0073e9SAndroid Build Coastguard Worker output_test = static_runtime_module.runAsync( 352*da0073e9SAndroid Build Coastguard Worker (input1, input2), {}) 353*da0073e9SAndroid Build Coastguard Worker except Exception as error: 354*da0073e9SAndroid Build Coastguard Worker expected_error_msg = ( 355*da0073e9SAndroid Build Coastguard Worker "The size of tensor a (7) must match the size " 356*da0073e9SAndroid Build Coastguard Worker "of tensor b (5) at non-singleton dimension 1" 357*da0073e9SAndroid Build Coastguard Worker ) 358*da0073e9SAndroid Build Coastguard Worker # test fails if error does not contain expected substr 359*da0073e9SAndroid Build Coastguard Worker if str(error).find(expected_error_msg) == -1: 360*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 361*da0073e9SAndroid Build Coastguard Worker "Tried execution of add.Tensors with incompatible shape. " 362*da0073e9SAndroid Build Coastguard Worker "Exception raised by forked runtime execution does " 363*da0073e9SAndroid Build Coastguard Worker f'not contain expected substring: "{expected_error_msg}"' 364*da0073e9SAndroid Build Coastguard Worker ) from error 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker def test_multihead_attention_layer(self): 367*da0073e9SAndroid Build Coastguard Worker HID_DIM = 256 368*da0073e9SAndroid Build Coastguard Worker QUERY_LEN = 8 369*da0073e9SAndroid Build Coastguard Worker BATCH_SIZE = 128 370*da0073e9SAndroid Build Coastguard Worker LAYERS = 3 371*da0073e9SAndroid Build Coastguard Worker HEADS = 8 372*da0073e9SAndroid Build Coastguard Worker DROPOUT = 0.1 373*da0073e9SAndroid Build Coastguard Worker device = torch.device("cpu") 374*da0073e9SAndroid Build Coastguard Worker attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) 375*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 376*da0073e9SAndroid Build Coastguard Worker src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) 377*da0073e9SAndroid Build Coastguard Worker src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker attention.eval() 380*da0073e9SAndroid Build Coastguard Worker attention = torch.jit.script(attention) 381*da0073e9SAndroid Build Coastguard Worker attention.eval() 382*da0073e9SAndroid Build Coastguard Worker o_ref = attention(src, src, src, src_mask) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker attention_a = StaticModule(attention) 385*da0073e9SAndroid Build Coastguard Worker o_test = attention_a(src, src, src, src_mask) 386*da0073e9SAndroid Build Coastguard Worker o_test_kw = attention_a(src, src, value=src, mask=src_mask) 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker for a, b in zip(o_ref, o_test): 389*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(a, b) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker for a, b in zip(o_ref, o_test_kw): 392*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(a, b) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker def test_multihead_attention_layer_benchmark(self): 395*da0073e9SAndroid Build Coastguard Worker HID_DIM = 256 396*da0073e9SAndroid Build Coastguard Worker QUERY_LEN = 8 397*da0073e9SAndroid Build Coastguard Worker BATCH_SIZE = 128 398*da0073e9SAndroid Build Coastguard Worker LAYERS = 3 399*da0073e9SAndroid Build Coastguard Worker HEADS = 8 400*da0073e9SAndroid Build Coastguard Worker DROPOUT = 0.1 401*da0073e9SAndroid Build Coastguard Worker device = torch.device("cpu") 402*da0073e9SAndroid Build Coastguard Worker attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) 403*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 404*da0073e9SAndroid Build Coastguard Worker src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) 405*da0073e9SAndroid Build Coastguard Worker src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker attention.eval() 408*da0073e9SAndroid Build Coastguard Worker attention = torch.jit.script(attention) 409*da0073e9SAndroid Build Coastguard Worker attention_a = StaticModule(attention) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker attention_a.benchmark([src, src, src, src_mask], {}, 2, 2) 412*da0073e9SAndroid Build Coastguard Worker metrics = attention_a.benchmark_individual_ops( 413*da0073e9SAndroid Build Coastguard Worker [src, src, src, src_mask], {}, 2, 2 414*da0073e9SAndroid Build Coastguard Worker ) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker def test_mlp(self): 417*da0073e9SAndroid Build Coastguard Worker # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh 418*da0073e9SAndroid Build Coastguard Worker ln_bot = [512, 512, 64] 419*da0073e9SAndroid Build Coastguard Worker sigmoid_bot = -1 420*da0073e9SAndroid Build Coastguard Worker ln_top = [100, 1024, 1024, 1024, 1] 421*da0073e9SAndroid Build Coastguard Worker sigmoid_top = 3 422*da0073e9SAndroid Build Coastguard Worker bot_l = create_mlp(ln_bot, sigmoid_bot) 423*da0073e9SAndroid Build Coastguard Worker bot_l_acc = StaticModule(bot_l) 424*da0073e9SAndroid Build Coastguard Worker top_l = create_mlp(ln_top, sigmoid_top) 425*da0073e9SAndroid Build Coastguard Worker top_l_acc = StaticModule(top_l) 426*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 427*da0073e9SAndroid Build Coastguard Worker bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) 428*da0073e9SAndroid Build Coastguard Worker top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) 429*da0073e9SAndroid Build Coastguard Worker ref_bot = bot_l(bot_inp) 430*da0073e9SAndroid Build Coastguard Worker acc_bot = bot_l_acc(bot_inp) 431*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(acc_bot, ref_bot) 432*da0073e9SAndroid Build Coastguard Worker ref_top = top_l(top_inp) 433*da0073e9SAndroid Build Coastguard Worker acc_top = top_l_acc(top_inp) 434*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(acc_top, ref_top) 435*da0073e9SAndroid Build Coastguard Worker for _ in range(5): 436*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 437*da0073e9SAndroid Build Coastguard Worker bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) 438*da0073e9SAndroid Build Coastguard Worker top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) 439*da0073e9SAndroid Build Coastguard Worker ref_bot = bot_l(bot_inp) 440*da0073e9SAndroid Build Coastguard Worker acc_bot = bot_l_acc(bot_inp) 441*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(acc_bot, ref_bot) 442*da0073e9SAndroid Build Coastguard Worker ref_top = top_l(top_inp) 443*da0073e9SAndroid Build Coastguard Worker acc_top = top_l_acc(top_inp) 444*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(acc_top, ref_top) 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker def test_trivial_graph(self): 447*da0073e9SAndroid Build Coastguard Worker s = torch.full((2, 2), 2) 448*da0073e9SAndroid Build Coastguard Worker tg = torch.jit.script(trivial_graph) 449*da0073e9SAndroid Build Coastguard Worker o_ref = tg(s, s, s) 450*da0073e9SAndroid Build Coastguard Worker tg_a = StaticModule(tg) 451*da0073e9SAndroid Build Coastguard Worker o_test = tg_a(s, s, s) 452*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(o_ref, o_test) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker def test_leaky_relu(self): 455*da0073e9SAndroid Build Coastguard Worker s = torch.randn(5, 5) 456*da0073e9SAndroid Build Coastguard Worker tg = torch.jit.script(nn.LeakyReLU(0.1)) 457*da0073e9SAndroid Build Coastguard Worker o_ref = tg(s) 458*da0073e9SAndroid Build Coastguard Worker tg_a = StaticModule(tg) 459*da0073e9SAndroid Build Coastguard Worker o_test = tg_a(s) 460*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(o_ref, o_test) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker def test_attr(self): 463*da0073e9SAndroid Build Coastguard Worker """ 464*da0073e9SAndroid Build Coastguard Worker TorchScript IR of TestModule() after freezing: 465*da0073e9SAndroid Build Coastguard Worker graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule, 466*da0073e9SAndroid Build Coastguard Worker %x.1 : Tensor): 467*da0073e9SAndroid Build Coastguard Worker %18 : int = prim::Constant[value=30]() 468*da0073e9SAndroid Build Coastguard Worker %30 : int = prim::Constant[value=13]() 469*da0073e9SAndroid Build Coastguard Worker %3 : int = prim::Constant[value=20]() 470*da0073e9SAndroid Build Coastguard Worker %2 : int = prim::Constant[value=1]() 471*da0073e9SAndroid Build Coastguard Worker %self.sub2.a : int = prim::Constant[value=12]() 472*da0073e9SAndroid Build Coastguard Worker %self.a : int = prim::Constant[value=3]() 473*da0073e9SAndroid Build Coastguard Worker = prim::SetAttr[name="b"](%self, %3) 474*da0073e9SAndroid Build Coastguard Worker %17 : Tensor = aten::add(%x.1, %30, %2) 475*da0073e9SAndroid Build Coastguard Worker %7 : Tensor = aten::add(%17, %self.a, %2) 476*da0073e9SAndroid Build Coastguard Worker %b.1 : int = prim::GetAttr[name="b"](%self) 477*da0073e9SAndroid Build Coastguard Worker %9 : Tensor = aten::add(%7, %b.1, %2) 478*da0073e9SAndroid Build Coastguard Worker %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self) 479*da0073e9SAndroid Build Coastguard Worker = prim::SetAttr[name="b"](%sub2, %18) 480*da0073e9SAndroid Build Coastguard Worker %b : int = prim::GetAttr[name="b"](%sub2) 481*da0073e9SAndroid Build Coastguard Worker %22 : int = aten::add(%self.sub2.a, %b) 482*da0073e9SAndroid Build Coastguard Worker %23 : Tensor = aten::add(%x.1, %22, %2) 483*da0073e9SAndroid Build Coastguard Worker %12 : Tensor = aten::add(%9, %23, %2) 484*da0073e9SAndroid Build Coastguard Worker return (%12) 485*da0073e9SAndroid Build Coastguard Worker """ 486*da0073e9SAndroid Build Coastguard Worker # test prim::SetAttr and prim::GetAttr impl in Static Runtime 487*da0073e9SAndroid Build Coastguard Worker m = TestModule() 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker m.eval() 490*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 491*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker ms = torch.jit.script(m) 494*da0073e9SAndroid Build Coastguard Worker sm = StaticModule(ms) 495*da0073e9SAndroid Build Coastguard Worker output_sm = sm(input) 496*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(output_s, output_sm) 497*da0073e9SAndroid Build Coastguard Worker sm.benchmark([input], {}, 2, 2) 498*da0073e9SAndroid Build Coastguard Worker sm.benchmark_individual_ops([input], {}, 2, 2) 499*da0073e9SAndroid Build Coastguard Worker sm.benchmark([], {"x": input}, 2, 2) 500*da0073e9SAndroid Build Coastguard Worker sm.benchmark_individual_ops([], {"x": input}, 2, 2) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Temporarily disabled") 503*da0073e9SAndroid Build Coastguard Worker def test_fusion_trivial_graph(self): 504*da0073e9SAndroid Build Coastguard Worker s = torch.full((2, 2), 2) 505*da0073e9SAndroid Build Coastguard Worker tg = torch.jit.script(trivial_graph) 506*da0073e9SAndroid Build Coastguard Worker o_ref = tg(s, s, s) 507*da0073e9SAndroid Build Coastguard Worker torch._C._fuse_to_static_module(tg.graph) 508*da0073e9SAndroid Build Coastguard Worker assert "StaticSubgraph" in str(tg.graph) 509*da0073e9SAndroid Build Coastguard Worker o_test = tg(s, s, s) 510*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(o_ref, o_test) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Temporarily disabled") 513*da0073e9SAndroid Build Coastguard Worker def test_fusion_multihead_attention_layer(self): 514*da0073e9SAndroid Build Coastguard Worker HID_DIM = 256 515*da0073e9SAndroid Build Coastguard Worker QUERY_LEN = 8 516*da0073e9SAndroid Build Coastguard Worker BATCH_SIZE = 128 517*da0073e9SAndroid Build Coastguard Worker LAYERS = 3 518*da0073e9SAndroid Build Coastguard Worker HEADS = 8 519*da0073e9SAndroid Build Coastguard Worker DROPOUT = 0.1 520*da0073e9SAndroid Build Coastguard Worker device = torch.device("cpu") 521*da0073e9SAndroid Build Coastguard Worker attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) 522*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 523*da0073e9SAndroid Build Coastguard Worker src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) 524*da0073e9SAndroid Build Coastguard Worker src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker attention.eval() 527*da0073e9SAndroid Build Coastguard Worker attention = torch.jit.script(attention) 528*da0073e9SAndroid Build Coastguard Worker attention.eval() 529*da0073e9SAndroid Build Coastguard Worker o_ref = attention(src, src, src, src_mask) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker torch._C._fuse_to_static_module(attention._c) 532*da0073e9SAndroid Build Coastguard Worker o_test = attention(src, src, src, src_mask) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker for a, b in zip(o_ref, o_test): 535*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(a, b) 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Temporarily disabled") 538*da0073e9SAndroid Build Coastguard Worker def test_fusion_loop(self): 539*da0073e9SAndroid Build Coastguard Worker a = torch.randn(5, 5) 540*da0073e9SAndroid Build Coastguard Worker b = torch.randn(5, 5) 541*da0073e9SAndroid Build Coastguard Worker c = 4 542*da0073e9SAndroid Build Coastguard Worker lg = torch.jit.script(loop_graph) 543*da0073e9SAndroid Build Coastguard Worker o_ref = lg(a, b, c) 544*da0073e9SAndroid Build Coastguard Worker torch._C._fuse_to_static_module(lg.graph) 545*da0073e9SAndroid Build Coastguard Worker assert "StaticSubgraph" in str(lg.graph) 546*da0073e9SAndroid Build Coastguard Worker o_test = lg(a, b, c) 547*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(o_ref, o_test) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Temporarily disabled") 550*da0073e9SAndroid Build Coastguard Worker def test_fusion_outputs(self): 551*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 2) 552*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2, 2) 553*da0073e9SAndroid Build Coastguard Worker c = 4 554*da0073e9SAndroid Build Coastguard Worker og = torch.jit.script(output_graph) 555*da0073e9SAndroid Build Coastguard Worker o_ref = og(a, b, b, c) 556*da0073e9SAndroid Build Coastguard Worker torch._C._fuse_to_static_module(og.graph) 557*da0073e9SAndroid Build Coastguard Worker assert "StaticSubgraph" in str(og.graph) 558*da0073e9SAndroid Build Coastguard Worker o_test = og(a, b, b, c) 559*da0073e9SAndroid Build Coastguard Worker for i in o_ref.keys(): 560*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(o_ref[i], o_test[i]) 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker def test_create_object(self): 563*da0073e9SAndroid Build Coastguard Worker class Foo: # noqa: B903 564*da0073e9SAndroid Build Coastguard Worker def __init__(self, x: torch.Tensor) -> None: 565*da0073e9SAndroid Build Coastguard Worker self.x = x 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 568*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 569*da0073e9SAndroid Build Coastguard Worker super().__init__() 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker def forward(self, y: torch.Tensor) -> torch.Tensor: 572*da0073e9SAndroid Build Coastguard Worker foo = Foo(y) 573*da0073e9SAndroid Build Coastguard Worker return y * foo.x 574*da0073e9SAndroid Build Coastguard Worker 575*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(Mod()).eval() 576*da0073e9SAndroid Build Coastguard Worker y = torch.randn((1, )) 577*da0073e9SAndroid Build Coastguard Worker expected = mod(y) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker static_mod = StaticModule(torch.jit.freeze(mod)) 580*da0073e9SAndroid Build Coastguard Worker actual = static_mod(y) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 585*da0073e9SAndroid Build Coastguard Worker run_tests() 586