xref: /aosp_15_r20/external/pytorch/test/test_static_runtime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, Optional
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport numpy as np
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.static_module import StaticModule
11*da0073e9SAndroid Build Coastguard Workerfrom typing import List
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute ``input @ weight.T`` (plus ``bias``) from matmul and add.

    Drop-in replacement for ``torch.nn.functional.linear`` expressed with
    primitive ops.

    Args:
        input: left operand of the matmul.
        weight: weight matrix; it is transposed before the matmul, matching
            the ``F.linear`` convention.
        bias: optional tensor added (with broadcasting) to the product.

    Returns:
        ``input.matmul(weight.t())``, with ``bias`` added when it is given.
    """
    output = input.matmul(weight.t())
    if bias is not None:
        # ``output`` was just produced by matmul above, so the in-place add
        # does not modify any of the caller's tensors.
        output += bias
    return output
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
# Replace F.linear process-wide with the matmul + add shim defined above.
# NOTE(review): this is a global side effect that applies to everything in the
# test process that looks up torch.nn.functional.linear after import —
# presumably so that scripted nn.Linear layers in the tests below lower to
# matmul/add instead of aten::linear; confirm intent.
torch.nn.functional.linear = linear_shim
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention built from four ``nn.Linear``s.

    Dropout and masking are disabled (their original lines are kept below as
    comments). The ``dropout`` constructor argument and the ``mask`` forward
    argument are accepted but never used.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        # The hidden size has to split evenly across the heads.
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        # Projections are created in q, k, v, o order; keeping this order keeps
        # seeded parameter initialisation identical.
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # Dropout disabled:
        # self.dropout = nn.Dropout(dropout)
        # 1-element float32 tensor holding sqrt(head_dim).
        root_head_dim = torch.sqrt(torch.FloatTensor([self.head_dim]))
        self.scale = root_head_dim.to(device)

    def forward(self, query, key, value, mask):
        # NOTE(review): assumes query/key/value are (batch, seq, hid_dim).
        n_batch = query.shape[0]
        # Project, then split heads: -> (batch, n_heads, seq, head_dim).
        q = self.fc_q(query).view(n_batch, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.fc_k(key).view(n_batch, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.fc_v(value).view(n_batch, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        # Swap the last two dims of k so the matmul yields per-head scores.
        k_t = k.permute(0, 1, 3, 2)
        scores = torch.matmul(q, k_t) / self.scale
        # Masking disabled (``mask`` is unused):
        # scores = scores.masked_fill(mask == 0, -1e10)
        weights = torch.softmax(scores, dim=-1)
        # Dropout on the attention weights disabled:
        # context = torch.matmul(self.dropout(weights), v)
        context = torch.matmul(weights, v)
        # Merge heads back: (batch, n_heads, seq, head_dim) -> (batch, seq, hid_dim).
        merged = context.permute(0, 2, 1, 3).contiguous().view(n_batch, -1, self.hid_dim)
        return self.fc_o(merged), weights
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker
# Taken from https://github.com/facebookresearch/dlrm/blob/master/dlrm_s_pytorch.py
def create_mlp(ln, sigmoid_layer):
    """Build a scripted, eval-mode MLP of ``nn.Linear`` + activation pairs.

    Weights are drawn from N(0, sqrt(2 / (m + n))) and biases from
    N(0, sqrt(1 / m)) using NumPy's global RNG. The draws happen in the same
    order as in DLRM, so results are reproducible under ``np.random.seed``.

    Args:
        ln: sequence of layer widths. Each consecutive pair ``(n, m)`` becomes
            ``nn.Linear(n, m)``. Entries are converted with ``int()``, so
            NumPy integer or integral float widths are accepted.
        sigmoid_layer: index of the Linear layer to follow with
            ``nn.Sigmoid``. Every other Linear is followed by ``nn.ReLU``.
            Use an index that matches no layer (e.g. -1) for ReLU only.

    Returns:
        A ``torch.jit.script``-ed ``nn.Sequential`` switched to eval mode.
    """
    layers = nn.ModuleList()
    for i, (n, m) in enumerate(zip(ln, ln[1:])):
        # Convert once: nn.Linear and np.random.normal(size=...) both need ints.
        n, m = int(n), int(m)
        layer = nn.Linear(n, m, bias=True)

        mean = 0.0
        # Glorot-normal std for the weight; sqrt(1 / m) for the bias.
        weight = np.random.normal(mean, np.sqrt(2 / (m + n)), size=(m, n)).astype(np.float32)
        bias = np.random.normal(mean, np.sqrt(1 / m), size=m).astype(np.float32)
        with torch.no_grad():
            layer.weight.copy_(torch.from_numpy(weight))
            layer.bias.copy_(torch.from_numpy(bias))
        layers.append(layer)

        layers.append(nn.Sigmoid() if i == sigmoid_layer else nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
def trivial_graph(a, b, c):
    """Return ``a + b * c`` plus a constant 2x2 tensor filled with 3."""
    offset = torch.tensor([[3, 3], [3, 3]])
    scaled = b * c
    return a + scaled + offset
92*da0073e9SAndroid Build Coastguard Worker
def elementwise_square_addition(input1, input2):
    """Return the elementwise sum of squares ``input1 * input1 + input2 * input2``."""
    first_squared = input1 * input1
    second_squared = input2 * input2
    return first_squared + second_squared
95*da0073e9SAndroid Build Coastguard Worker
def fork_wait_graph1(input1, input2):
    """Fork ``elementwise_square_addition`` on the inputs and wait for its result."""
    future = torch.jit.fork(elementwise_square_addition, input1, input2)
    result = torch.jit.wait(future)
    return result
99*da0073e9SAndroid Build Coastguard Worker
def fork_wait_graph2(input1, input2):
    """Fork ``loop_graph`` with a fixed 5 iterations and wait for its result."""
    future = torch.jit.fork(loop_graph, input1, input2, 5)
    result = torch.jit.wait(future)
    return result
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker"""
105*da0073e9SAndroid Build Coastguard Worker   graph with multiple fork/wait operations
106*da0073e9SAndroid Build Coastguard Worker   :param input: torch.tensor input to forked subgraph
107*da0073e9SAndroid Build Coastguard Worker   :param iters: number of future/wait pairs to be created
108*da0073e9SAndroid Build Coastguard Worker"""
109*da0073e9SAndroid Build Coastguard Workerdef fork_wait_graph3(input, iters: int):
110*da0073e9SAndroid Build Coastguard Worker    futures : List[torch.jit.Future[torch.Tensor]] = []
111*da0073e9SAndroid Build Coastguard Worker    for _ in range(iters):
112*da0073e9SAndroid Build Coastguard Worker        futures.append(torch.jit.fork(torch.neg, input))
113*da0073e9SAndroid Build Coastguard Worker    results = []
114*da0073e9SAndroid Build Coastguard Worker    for future in futures:
115*da0073e9SAndroid Build Coastguard Worker        results.append(torch.jit.wait(future))
116*da0073e9SAndroid Build Coastguard Worker    return torch.sum(torch.stack(results))
117*da0073e9SAndroid Build Coastguard Worker
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    """Graph with multi-level (nested) fork/wait operations.

    Each of the ``num_forks`` top-level forks runs ``fork_wait_graph3``,
    which itself forks ``num_child_forks`` children. The description used to
    be a bare module-level string placed before the ``def``; it now lives
    here as the real docstring.

    :param input: torch.tensor input to forked subgraph
    :param num_forks: number of top level forks; must be >= 1, because the
        results are combined with ``torch.stack``
    :param num_child_forks: number of child forks per parent fork
    :return: ``torch.sum`` over the stacked top-level fork results
    """
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))
132*da0073e9SAndroid Build Coastguard Worker
def add_tensor(input1, input2):
    """Return ``input1 + input2`` (shape mismatches raise from the add)."""
    total = input1 + input2
    return total
135*da0073e9SAndroid Build Coastguard Worker
def fork_wait_graph_exception(input1, input2):
    """Fork ``add_tensor`` on the inputs and wait; errors surface at the wait."""
    future = torch.jit.fork(add_tensor, input1, input2)
    result = torch.jit.wait(future)
    return result
139*da0073e9SAndroid Build Coastguard Worker
def loop_graph(a, b, iters: int):
    """Start from ``a + b * 2`` and apply ``iters`` rounds of ``(acc + b) * 2 - a``.

    The in-place ``*=`` / ``-=`` only touch ``acc``, which is always a value
    freshly produced inside this function, never ``a`` or ``b``.
    """
    acc = a + b * 2
    for _ in range(iters):
        acc = acc + b
        acc *= 2
        acc -= a
    return acc
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def output_graph(a, b, c, iters: int):
    """Return ``{i: a + b * c + C + i for i in range(iters)}``.

    Here ``C`` is a constant 2x2 tensor filled with 3.
    """
    base = torch.tensor([[3, 3], [3, 3]])
    shared = a + b * c + base
    result: Dict[int, torch.Tensor] = {}
    for idx in range(iters):
        result[idx] = shared + idx
    return result
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker
class SubModule(nn.Module):
    """Module with two int attributes whose forward returns ``a + b + x``."""

    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        # The int attributes are summed first, then added to ``x``.
        attr_sum = self.a + self.b
        return attr_sum + x
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker
class SubModule2(nn.Module):
    """Module whose forward overwrites ``self.b`` before returning ``a + b + x``."""

    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        # Deliberate attribute mutation inside forward: b becomes 30.
        self.b = 30
        attr_sum = self.a + self.b
        return attr_sum + x
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker
class TestModule(nn.Module):
    """Module nesting SubModule and SubModule2 and mutating its own ``b`` in forward."""

    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        # Overwrite b first, so the read below sees 20.
        self.b = 20
        # Left-to-right evaluation: sub1, a, b, then sub2 (same as before).
        first = self.sub1(x)
        return first + self.a + self.b + self.sub2(x)
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Workerclass TestStaticModule(TestCase):
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker    """
195*da0073e9SAndroid Build Coastguard Worker    Test Case: To test simple fork/wait operation in a graph
196*da0073e9SAndroid Build Coastguard Worker    fork is called on simple addition operation on input tensors
197*da0073e9SAndroid Build Coastguard Worker    """
198*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_1(self):
199*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.ones(5, 5)
200*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(5, 5)
201*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph1)
202*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(inp1, inp2)
203*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
204*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module(inp1, inp2)
205*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test, output_ref)
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker    """
208*da0073e9SAndroid Build Coastguard Worker    Test Case: To test simple fork/wait operation with
209*da0073e9SAndroid Build Coastguard Worker    StaticRuntime runAsync API returning future
210*da0073e9SAndroid Build Coastguard Worker    """
211*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_1_async(self):
212*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.ones(5, 5)
213*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(5, 5)
214*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph1)
215*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(inp1, inp2)
216*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
217*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module.runAsync((inp1, inp2), {})
218*da0073e9SAndroid Build Coastguard Worker        output_test.wait()
219*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test.value(), output_ref)
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker    """
222*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation in a graph on
223*da0073e9SAndroid Build Coastguard Worker    a loop subgraph performing mix of operations
224*da0073e9SAndroid Build Coastguard Worker    """
225*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_2(self):
226*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.randn(5, 5)
227*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(5, 5)
228*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph2)
229*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(inp1, inp2)
230*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
231*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module(inp1, inp2)
232*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test, output_ref)
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker    """
235*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation on a loop
236*da0073e9SAndroid Build Coastguard Worker    subgraph with StaticRuntime runAsync API returning future
237*da0073e9SAndroid Build Coastguard Worker    """
238*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_2_async(self):
239*da0073e9SAndroid Build Coastguard Worker        inp1 = torch.randn(5, 5)
240*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.randn(5, 5)
241*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph2)
242*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(inp1, inp2)
243*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
244*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module.runAsync((inp1, inp2), {})
245*da0073e9SAndroid Build Coastguard Worker        output_test.wait()
246*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test.value(), output_ref)
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker    """
249*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation in a graph on
250*da0073e9SAndroid Build Coastguard Worker    having multiple fork/wait operations
251*da0073e9SAndroid Build Coastguard Worker    """
252*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_3(self):
253*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(3, 3)
254*da0073e9SAndroid Build Coastguard Worker        num_forks = 10
255*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph3)
256*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(input, num_forks)
257*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
258*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module(input, num_forks)
259*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test, output_ref)
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker    """
262*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation in a graph with
263*da0073e9SAndroid Build Coastguard Worker    multiple fork/wait operations on runAsync API returning future
264*da0073e9SAndroid Build Coastguard Worker    """
265*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_3_async(self):
266*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(3, 3)
267*da0073e9SAndroid Build Coastguard Worker        num_forks = 10
268*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph3)
269*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(input, num_forks)
270*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
271*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module.runAsync((input, num_forks), {})
272*da0073e9SAndroid Build Coastguard Worker        output_test.wait()
273*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test.value(), output_ref)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker    """
276*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation in a graph on
277*da0073e9SAndroid Build Coastguard Worker    multiple nested fork/wait operations
278*da0073e9SAndroid Build Coastguard Worker    """
279*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
280*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_4(self):
281*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(3, 3)
282*da0073e9SAndroid Build Coastguard Worker        num_forks = 10
283*da0073e9SAndroid Build Coastguard Worker        num_child_forks = 10
284*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph4)
285*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
286*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(input, num_forks, num_child_forks)
287*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module(input, num_forks, num_child_forks)
288*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test, output_ref)
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker    """
291*da0073e9SAndroid Build Coastguard Worker    Test Case: To test fork/wait operation in a graph with multiple
292*da0073e9SAndroid Build Coastguard Worker    nested fork/wait operations on runAsync API returning future
293*da0073e9SAndroid Build Coastguard Worker    """
294*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
295*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_4_async(self):
296*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(3, 3)
297*da0073e9SAndroid Build Coastguard Worker        num_forks = 10
298*da0073e9SAndroid Build Coastguard Worker        num_child_forks = 10
299*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph4)
300*da0073e9SAndroid Build Coastguard Worker        static_runtime_module = StaticModule(torch_graph)
301*da0073e9SAndroid Build Coastguard Worker        output_ref = torch_graph(input, num_forks, num_child_forks)
302*da0073e9SAndroid Build Coastguard Worker        output_test = static_runtime_module.runAsync(
303*da0073e9SAndroid Build Coastguard Worker            (input, num_forks, num_child_forks), {})
304*da0073e9SAndroid Build Coastguard Worker        output_test.wait()
305*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_test.value(), output_ref)
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker    """
308*da0073e9SAndroid Build Coastguard Worker    Test Case: To test exception handling in fork/wait
309*da0073e9SAndroid Build Coastguard Worker    operation. Add.Tensor op is called for tensors with
310*da0073e9SAndroid Build Coastguard Worker    non-matching dims on the forked subgraph and the
311*da0073e9SAndroid Build Coastguard Worker    exception raised by subgraph is set on future returned
312*da0073e9SAndroid Build Coastguard Worker    by prim::fork to parent graph. Returned exception is
313*da0073e9SAndroid Build Coastguard Worker    checked for substring expected_error_msg as declared below
314*da0073e9SAndroid Build Coastguard Worker    """
315*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_exception(self):
316*da0073e9SAndroid Build Coastguard Worker        # incompatible tensors for add due to shape mismatch
317*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 7)
318*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 5)
319*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph_exception)
320*da0073e9SAndroid Build Coastguard Worker        try:
321*da0073e9SAndroid Build Coastguard Worker            static_runtime_module = StaticModule(torch_graph)
322*da0073e9SAndroid Build Coastguard Worker            output_test = static_runtime_module(input1, input2)
323*da0073e9SAndroid Build Coastguard Worker        except Exception as error:
324*da0073e9SAndroid Build Coastguard Worker            expected_error_msg = (
325*da0073e9SAndroid Build Coastguard Worker                "The size of tensor a (7) must match the size "
326*da0073e9SAndroid Build Coastguard Worker                "of tensor b (5) at non-singleton dimension 1"
327*da0073e9SAndroid Build Coastguard Worker            )
328*da0073e9SAndroid Build Coastguard Worker            # test fails if error does not contain expected substr
329*da0073e9SAndroid Build Coastguard Worker            if str(error).find(expected_error_msg) == -1:
330*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
331*da0073e9SAndroid Build Coastguard Worker                    "Tried execution of add.Tensors with incompatible shape. "
332*da0073e9SAndroid Build Coastguard Worker                    "Exception raised by forked runtime execution does "
333*da0073e9SAndroid Build Coastguard Worker                    f'not contain expected substring: "{expected_error_msg}"'
334*da0073e9SAndroid Build Coastguard Worker                ) from error
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker    """
337*da0073e9SAndroid Build Coastguard Worker    Test Case: To test exception handling in fork/wait
338*da0073e9SAndroid Build Coastguard Worker    operation with runAsync API. Add.Tensor op is called for
339*da0073e9SAndroid Build Coastguard Worker    tensors with non-matching dims on the forked subgraph
340*da0073e9SAndroid Build Coastguard Worker    and the exception raised by subgraph is set on future returned
341*da0073e9SAndroid Build Coastguard Worker    by prim::fork to parent graph. Returned exception is
342*da0073e9SAndroid Build Coastguard Worker    checked for substring expected_error_msg as declared below
343*da0073e9SAndroid Build Coastguard Worker    """
344*da0073e9SAndroid Build Coastguard Worker    def test_fork_wait_exception_async(self):
345*da0073e9SAndroid Build Coastguard Worker        # incompatible tensors for add due to shape mismatch
346*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(4, 7)
347*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(4, 5)
348*da0073e9SAndroid Build Coastguard Worker        torch_graph = torch.jit.script(fork_wait_graph_exception)
349*da0073e9SAndroid Build Coastguard Worker        try:
350*da0073e9SAndroid Build Coastguard Worker            static_runtime_module = StaticModule(torch_graph)
351*da0073e9SAndroid Build Coastguard Worker            output_test = static_runtime_module.runAsync(
352*da0073e9SAndroid Build Coastguard Worker                (input1, input2), {})
353*da0073e9SAndroid Build Coastguard Worker        except Exception as error:
354*da0073e9SAndroid Build Coastguard Worker            expected_error_msg = (
355*da0073e9SAndroid Build Coastguard Worker                "The size of tensor a (7) must match the size "
356*da0073e9SAndroid Build Coastguard Worker                "of tensor b (5) at non-singleton dimension 1"
357*da0073e9SAndroid Build Coastguard Worker            )
358*da0073e9SAndroid Build Coastguard Worker            # test fails if error does not contain expected substr
359*da0073e9SAndroid Build Coastguard Worker            if str(error).find(expected_error_msg) == -1:
360*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
361*da0073e9SAndroid Build Coastguard Worker                    "Tried execution of add.Tensors with incompatible shape. "
362*da0073e9SAndroid Build Coastguard Worker                    "Exception raised by forked runtime execution does "
363*da0073e9SAndroid Build Coastguard Worker                    f'not contain expected substring: "{expected_error_msg}"'
364*da0073e9SAndroid Build Coastguard Worker                ) from error
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker    def test_multihead_attention_layer(self):
367*da0073e9SAndroid Build Coastguard Worker        HID_DIM = 256
368*da0073e9SAndroid Build Coastguard Worker        QUERY_LEN = 8
369*da0073e9SAndroid Build Coastguard Worker        BATCH_SIZE = 128
370*da0073e9SAndroid Build Coastguard Worker        LAYERS = 3
371*da0073e9SAndroid Build Coastguard Worker        HEADS = 8
372*da0073e9SAndroid Build Coastguard Worker        DROPOUT = 0.1
373*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cpu")
374*da0073e9SAndroid Build Coastguard Worker        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
375*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
376*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
377*da0073e9SAndroid Build Coastguard Worker        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        attention.eval()
380*da0073e9SAndroid Build Coastguard Worker        attention = torch.jit.script(attention)
381*da0073e9SAndroid Build Coastguard Worker        attention.eval()
382*da0073e9SAndroid Build Coastguard Worker        o_ref = attention(src, src, src, src_mask)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker        attention_a = StaticModule(attention)
385*da0073e9SAndroid Build Coastguard Worker        o_test = attention_a(src, src, src, src_mask)
386*da0073e9SAndroid Build Coastguard Worker        o_test_kw = attention_a(src, src, value=src, mask=src_mask)
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(o_ref, o_test):
389*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(a, b)
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(o_ref, o_test_kw):
392*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(a, b)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker    def test_multihead_attention_layer_benchmark(self):
395*da0073e9SAndroid Build Coastguard Worker        HID_DIM = 256
396*da0073e9SAndroid Build Coastguard Worker        QUERY_LEN = 8
397*da0073e9SAndroid Build Coastguard Worker        BATCH_SIZE = 128
398*da0073e9SAndroid Build Coastguard Worker        LAYERS = 3
399*da0073e9SAndroid Build Coastguard Worker        HEADS = 8
400*da0073e9SAndroid Build Coastguard Worker        DROPOUT = 0.1
401*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cpu")
402*da0073e9SAndroid Build Coastguard Worker        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
403*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
404*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
405*da0073e9SAndroid Build Coastguard Worker        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        attention.eval()
408*da0073e9SAndroid Build Coastguard Worker        attention = torch.jit.script(attention)
409*da0073e9SAndroid Build Coastguard Worker        attention_a = StaticModule(attention)
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
412*da0073e9SAndroid Build Coastguard Worker        metrics = attention_a.benchmark_individual_ops(
413*da0073e9SAndroid Build Coastguard Worker            [src, src, src, src_mask], {}, 2, 2
414*da0073e9SAndroid Build Coastguard Worker        )
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker    def test_mlp(self):
417*da0073e9SAndroid Build Coastguard Worker        # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh
418*da0073e9SAndroid Build Coastguard Worker        ln_bot = [512, 512, 64]
419*da0073e9SAndroid Build Coastguard Worker        sigmoid_bot = -1
420*da0073e9SAndroid Build Coastguard Worker        ln_top = [100, 1024, 1024, 1024, 1]
421*da0073e9SAndroid Build Coastguard Worker        sigmoid_top = 3
422*da0073e9SAndroid Build Coastguard Worker        bot_l = create_mlp(ln_bot, sigmoid_bot)
423*da0073e9SAndroid Build Coastguard Worker        bot_l_acc = StaticModule(bot_l)
424*da0073e9SAndroid Build Coastguard Worker        top_l = create_mlp(ln_top, sigmoid_top)
425*da0073e9SAndroid Build Coastguard Worker        top_l_acc = StaticModule(top_l)
426*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
427*da0073e9SAndroid Build Coastguard Worker            bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
428*da0073e9SAndroid Build Coastguard Worker            top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
429*da0073e9SAndroid Build Coastguard Worker        ref_bot = bot_l(bot_inp)
430*da0073e9SAndroid Build Coastguard Worker        acc_bot = bot_l_acc(bot_inp)
431*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(acc_bot, ref_bot)
432*da0073e9SAndroid Build Coastguard Worker        ref_top = top_l(top_inp)
433*da0073e9SAndroid Build Coastguard Worker        acc_top = top_l_acc(top_inp)
434*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(acc_top, ref_top)
435*da0073e9SAndroid Build Coastguard Worker        for _ in range(5):
436*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
437*da0073e9SAndroid Build Coastguard Worker                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
438*da0073e9SAndroid Build Coastguard Worker                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
439*da0073e9SAndroid Build Coastguard Worker            ref_bot = bot_l(bot_inp)
440*da0073e9SAndroid Build Coastguard Worker            acc_bot = bot_l_acc(bot_inp)
441*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(acc_bot, ref_bot)
442*da0073e9SAndroid Build Coastguard Worker            ref_top = top_l(top_inp)
443*da0073e9SAndroid Build Coastguard Worker            acc_top = top_l_acc(top_inp)
444*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(acc_top, ref_top)
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker    def test_trivial_graph(self):
447*da0073e9SAndroid Build Coastguard Worker        s = torch.full((2, 2), 2)
448*da0073e9SAndroid Build Coastguard Worker        tg = torch.jit.script(trivial_graph)
449*da0073e9SAndroid Build Coastguard Worker        o_ref = tg(s, s, s)
450*da0073e9SAndroid Build Coastguard Worker        tg_a = StaticModule(tg)
451*da0073e9SAndroid Build Coastguard Worker        o_test = tg_a(s, s, s)
452*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(o_ref, o_test)
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker    def test_leaky_relu(self):
455*da0073e9SAndroid Build Coastguard Worker        s = torch.randn(5, 5)
456*da0073e9SAndroid Build Coastguard Worker        tg = torch.jit.script(nn.LeakyReLU(0.1))
457*da0073e9SAndroid Build Coastguard Worker        o_ref = tg(s)
458*da0073e9SAndroid Build Coastguard Worker        tg_a = StaticModule(tg)
459*da0073e9SAndroid Build Coastguard Worker        o_test = tg_a(s)
460*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(o_ref, o_test)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker    def test_attr(self):
463*da0073e9SAndroid Build Coastguard Worker        """
464*da0073e9SAndroid Build Coastguard Worker        TorchScript IR of TestModule() after freezing:
465*da0073e9SAndroid Build Coastguard Worker        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
466*da0073e9SAndroid Build Coastguard Worker              %x.1 : Tensor):
467*da0073e9SAndroid Build Coastguard Worker            %18 : int = prim::Constant[value=30]()
468*da0073e9SAndroid Build Coastguard Worker            %30 : int = prim::Constant[value=13]()
469*da0073e9SAndroid Build Coastguard Worker            %3 : int = prim::Constant[value=20]()
470*da0073e9SAndroid Build Coastguard Worker            %2 : int = prim::Constant[value=1]()
471*da0073e9SAndroid Build Coastguard Worker            %self.sub2.a : int = prim::Constant[value=12]()
472*da0073e9SAndroid Build Coastguard Worker            %self.a : int = prim::Constant[value=3]()
473*da0073e9SAndroid Build Coastguard Worker            = prim::SetAttr[name="b"](%self, %3)
474*da0073e9SAndroid Build Coastguard Worker            %17 : Tensor = aten::add(%x.1, %30, %2)
475*da0073e9SAndroid Build Coastguard Worker            %7 : Tensor = aten::add(%17, %self.a, %2)
476*da0073e9SAndroid Build Coastguard Worker            %b.1 : int = prim::GetAttr[name="b"](%self)
477*da0073e9SAndroid Build Coastguard Worker            %9 : Tensor = aten::add(%7, %b.1, %2)
478*da0073e9SAndroid Build Coastguard Worker            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
479*da0073e9SAndroid Build Coastguard Worker            = prim::SetAttr[name="b"](%sub2, %18)
480*da0073e9SAndroid Build Coastguard Worker            %b : int = prim::GetAttr[name="b"](%sub2)
481*da0073e9SAndroid Build Coastguard Worker            %22 : int = aten::add(%self.sub2.a, %b)
482*da0073e9SAndroid Build Coastguard Worker            %23 : Tensor = aten::add(%x.1, %22, %2)
483*da0073e9SAndroid Build Coastguard Worker            %12 : Tensor = aten::add(%9, %23, %2)
484*da0073e9SAndroid Build Coastguard Worker            return (%12)
485*da0073e9SAndroid Build Coastguard Worker        """
486*da0073e9SAndroid Build Coastguard Worker        # test prim::SetAttr and prim::GetAttr impl in Static Runtime
487*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        m.eval()
490*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 2)
491*da0073e9SAndroid Build Coastguard Worker        output_s = m.forward(input)
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker        ms = torch.jit.script(m)
494*da0073e9SAndroid Build Coastguard Worker        sm = StaticModule(ms)
495*da0073e9SAndroid Build Coastguard Worker        output_sm = sm(input)
496*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(output_s, output_sm)
497*da0073e9SAndroid Build Coastguard Worker        sm.benchmark([input], {}, 2, 2)
498*da0073e9SAndroid Build Coastguard Worker        sm.benchmark_individual_ops([input], {}, 2, 2)
499*da0073e9SAndroid Build Coastguard Worker        sm.benchmark([], {"x": input}, 2, 2)
500*da0073e9SAndroid Build Coastguard Worker        sm.benchmark_individual_ops([], {"x": input}, 2, 2)
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Temporarily disabled")
503*da0073e9SAndroid Build Coastguard Worker    def test_fusion_trivial_graph(self):
504*da0073e9SAndroid Build Coastguard Worker        s = torch.full((2, 2), 2)
505*da0073e9SAndroid Build Coastguard Worker        tg = torch.jit.script(trivial_graph)
506*da0073e9SAndroid Build Coastguard Worker        o_ref = tg(s, s, s)
507*da0073e9SAndroid Build Coastguard Worker        torch._C._fuse_to_static_module(tg.graph)
508*da0073e9SAndroid Build Coastguard Worker        assert "StaticSubgraph" in str(tg.graph)
509*da0073e9SAndroid Build Coastguard Worker        o_test = tg(s, s, s)
510*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(o_ref, o_test)
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Temporarily disabled")
513*da0073e9SAndroid Build Coastguard Worker    def test_fusion_multihead_attention_layer(self):
514*da0073e9SAndroid Build Coastguard Worker        HID_DIM = 256
515*da0073e9SAndroid Build Coastguard Worker        QUERY_LEN = 8
516*da0073e9SAndroid Build Coastguard Worker        BATCH_SIZE = 128
517*da0073e9SAndroid Build Coastguard Worker        LAYERS = 3
518*da0073e9SAndroid Build Coastguard Worker        HEADS = 8
519*da0073e9SAndroid Build Coastguard Worker        DROPOUT = 0.1
520*da0073e9SAndroid Build Coastguard Worker        device = torch.device("cpu")
521*da0073e9SAndroid Build Coastguard Worker        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
522*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
523*da0073e9SAndroid Build Coastguard Worker            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
524*da0073e9SAndroid Build Coastguard Worker        src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker        attention.eval()
527*da0073e9SAndroid Build Coastguard Worker        attention = torch.jit.script(attention)
528*da0073e9SAndroid Build Coastguard Worker        attention.eval()
529*da0073e9SAndroid Build Coastguard Worker        o_ref = attention(src, src, src, src_mask)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker        torch._C._fuse_to_static_module(attention._c)
532*da0073e9SAndroid Build Coastguard Worker        o_test = attention(src, src, src, src_mask)
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(o_ref, o_test):
535*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(a, b)
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Temporarily disabled")
538*da0073e9SAndroid Build Coastguard Worker    def test_fusion_loop(self):
539*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5)
540*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 5)
541*da0073e9SAndroid Build Coastguard Worker        c = 4
542*da0073e9SAndroid Build Coastguard Worker        lg = torch.jit.script(loop_graph)
543*da0073e9SAndroid Build Coastguard Worker        o_ref = lg(a, b, c)
544*da0073e9SAndroid Build Coastguard Worker        torch._C._fuse_to_static_module(lg.graph)
545*da0073e9SAndroid Build Coastguard Worker        assert "StaticSubgraph" in str(lg.graph)
546*da0073e9SAndroid Build Coastguard Worker        o_test = lg(a, b, c)
547*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(o_ref, o_test)
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Temporarily disabled")
550*da0073e9SAndroid Build Coastguard Worker    def test_fusion_outputs(self):
551*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2)
552*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2)
553*da0073e9SAndroid Build Coastguard Worker        c = 4
554*da0073e9SAndroid Build Coastguard Worker        og = torch.jit.script(output_graph)
555*da0073e9SAndroid Build Coastguard Worker        o_ref = og(a, b, b, c)
556*da0073e9SAndroid Build Coastguard Worker        torch._C._fuse_to_static_module(og.graph)
557*da0073e9SAndroid Build Coastguard Worker        assert "StaticSubgraph" in str(og.graph)
558*da0073e9SAndroid Build Coastguard Worker        o_test = og(a, b, b, c)
559*da0073e9SAndroid Build Coastguard Worker        for i in o_ref.keys():
560*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(o_ref[i], o_test[i])
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker    def test_create_object(self):
563*da0073e9SAndroid Build Coastguard Worker        class Foo:  # noqa: B903
564*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x: torch.Tensor) -> None:
565*da0073e9SAndroid Build Coastguard Worker                self.x = x
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
568*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
569*da0073e9SAndroid Build Coastguard Worker                super().__init__()
570*da0073e9SAndroid Build Coastguard Worker
571*da0073e9SAndroid Build Coastguard Worker            def forward(self, y: torch.Tensor) -> torch.Tensor:
572*da0073e9SAndroid Build Coastguard Worker                foo = Foo(y)
573*da0073e9SAndroid Build Coastguard Worker                return y * foo.x
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Worker        mod = torch.jit.script(Mod()).eval()
576*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((1, ))
577*da0073e9SAndroid Build Coastguard Worker        expected = mod(y)
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker        static_mod = StaticModule(torch.jit.freeze(mod))
580*da0073e9SAndroid Build Coastguard Worker        actual = static_mod(y)
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, actual)
583*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this file's test cases through PyTorch's shared test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()
586