# Owner(s): ["oncall: jit"]

import copy
import io
import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import warnings

# Standard library
from collections import namedtuple
from itertools import chain
from typing import Dict, List, Optional, Tuple

from torch import Tensor
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    IS_SANDCASTLE,
    skipIfCompiledWithoutNumpy,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TemporaryFileName,
)
from torch.testing._internal.jit_utils import (
    _tmp_donotuse_dont_inline_everything,
    _trace,
    enable_cpu_fuser,
    JitTestCase,
    make_global,
    RUN_CUDA,
    RUN_CUDA_MULTI_GPU,
)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
        class Recurrence(nn.Module):
            def __init__(self, seq_len):
                super().__init__()
                self.seq_len = seq_len

            def forward(self, input):
                input = input.transpose(0, 1)

                # Main loop
                output = []
                for i in range(self.seq_len):
                    b = input[i] * 2
                    output.append(b)

                output = torch.cat(output, 0).view(input.size(0), *output[0].size())
                output = output.transpose(0, 1)
                return output

        input_size = 8
        batch_size = 2
        seq_len = 130

        rec = Recurrence(seq_len)
        input = torch.rand(batch_size, seq_len, input_size)

        torch.cuda.set_device(0)
        rec = rec.cuda()
        input = input.cuda()

        traced_rec = torch.jit.trace(rec, (input))

    def test_trace_legacy_ctor(self):
        class MyModule(nn.Module):
            def forward(self, x):
                return (x + 1, torch.FloatTensor([0]))

        traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2))

    def test_simple(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        self.checkTrace(f, (x, y))

    def test_trace_checking_with_global_name(self):
        class MyClass(torch.nn.Module):
            def forward(self, xs: List[Tensor]):
                y = torch.cat(xs, dim=0)
                return y

        model = MyClass()
        # Simulate these inputs being in the globals, like they would be if,
        # e.g., they were defined in the outermost scope of a script
        global input1, input2
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 2)
        m2 = torch.jit.trace(model, ((input1, input2),))

    def test_trace_aliased_parameter(self):
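        # The example input aliases the module's own parameter (m.x); the
        # traced module should still compute self.x + y for a fresh input.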
        class M(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = nn.Parameter(x)

            def forward(self, y):
                return self.x + y

        m = M(torch.rand(3, 4))
        r = torch.jit.trace(m, m.x)
        t2 = torch.rand(3, 4)
        self.assertEqual(r(t2), m.x + t2)

    def test_trace_nested_fn(self):
        class TracedInlineDecision(torch.nn.Module):
            def forward(self, x, flag):
                @torch.jit.script
                def make_decision(flag, x):
                    if flag:
                        return x
                    else:
                        return torch.zeros_like(x)

                x = torch.neg(x)
                return make_decision(flag, x)

        decision = TracedInlineDecision()
        torch.jit.trace(
            decision,
            (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)),
            check_trace=True,
        )

    def test_trace_single_tuple(self):
        x = torch.tensor(2.0)

        def f2(x):
            return (x,)

        jit_f2 = torch.jit.trace(f2, x)
        assert f2(x) == jit_f2(x)  # fails

    def test_trace_out_operator_with_two_output(self):
        example_input = torch.rand(2, 8)
        out_1, out_2 = torch.cummax(example_input, 1)

        def run_cummax(example_input, out_1, out_2):
            output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2))
            return output_1, output_2

        trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2))

    def test_trace_namedtuple(self):
        Point = namedtuple("point", ["x", "y"])

        def f(p):
            if type(p) is tuple:
                p = Point(*p)
            return p.x + p.y

        p = Point(torch.randn(1), torch.randn(1))
        traced = torch.jit.trace(f, (p,))
        self.assertEqual(f(p), traced(p))

    def test_trace_topk(self):
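        # topk is called with a Tensor k, so the trace should generalize to new
        # input shapes and k values; re-running the traced module should not
        # emit slicing-related tracer warnings.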
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x.topk(y, dim=1)[1]

        mod = M()
        inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17))
        traced_func = torch.jit.trace(mod, inputs)

        test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw slicing related warn here",
        )
        self.assertEqual(eager_out, traced_out)

        test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw slicing related warn here",
        )
        self.assertEqual(eager_out, traced_out)

    def test_typeas_trace_check(self):
        a = torch.tensor([0.4], requires_grad=True)
        b = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return x.type_as(y)

        trace = torch.jit.trace(f, (a, b))

    def test_trace_index(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0], dtype=torch.int64)

        def fn(x, y):
            return x[y]

        fn_traced = torch.jit.trace(
            fn,
            (
                x,
                y,
            ),
        )

        self.assertEqual(fn(x, y), fn_traced(x, y))

    # Backwards tracing was broken for indexing by a constant,
    # because it's internally implemented using as_strided,
    # and we attempted to trace its derivative (which is not
    # currently supported). It currently works because
    # slice() is now not marked as traceable.
    def test_trace_index_constant(self):
        x = torch.tensor([0.4], requires_grad=True)

        def fn(x):
            return x[0]

        def run(f):
            y = f(x)
            grad = torch.autograd.grad(y, x)[0].clone()
            return y, grad

        traced_fn = torch.jit.trace(fn, torch.ones(1))
        self.assertEqual(run(fn), run(traced_fn))

    def test_index_put(self):
        ten = torch.zeros(3, 3)
        mask = torch.tensor(
            [[True, True, True], [True, False, False], [True, True, False]]
        )

        def test_fn(ten, mask):
            ten[mask] = torch.ones(6)
            return ten

        traced_test_fn = torch.jit.trace(test_fn, (ten, mask))

        ten = torch.rand(3, 3)
        self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))

    def test_canonicalize_tensor_iterator(self):
        x = torch.randn(4, 4)

        def f(x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

        traced = torch.jit.trace(f, (x,))
        f(x)
        graph = traced.graph_for(x)
        # There should be 4 int constants for the right sides of operators, plus one
        # for the alpha argument for add and sub
        self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 5)

    @suppress_warnings
    def test_constant(self):
        x = torch.randn(2, 2, requires_grad=True)

        def f(x):
            return x.matmul(torch.diag(torch.tensor([2.0, 2.0])))

        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))

    def test_wrapped_number(self):
        # Scalars get converted to 'wrapped' tensors of default tensor type.
        # Wrapped tensors behave differently in certain promotion operations:
        # float_tensor * double -> float but wrapped_float * double -> double.
        # This can cause issues in check-trace if not handled correctly in
        # `aten::isclose()`.
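        # For example, in eager mode torch.ones(1, dtype=torch.float) * (-10000.0)
        # stays float32 because the Python scalar is a wrapped number, while
        # multiplying by an ordinary float64 tensor would promote to float64.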

        def foobar():
            x = -10000.0
            result = x * torch.ones(1, dtype=torch.float)
            return result

        scripted = torch.jit.trace(foobar, (), check_trace=True)

    def test_inplace_transplant(self):
        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = x.clone()
            y.add_(2)
            y.add_(3)
            return y

        g, _ = torch.jit._get_trace_graph(fn, (x,))
        self.run_pass("dce", g)
        FileCheck().check_count("aten::clone", 1, exactly=True).check_count(
            "aten::add_", 2, exactly=True
        ).check_next("return").run(str(g))
        self.assertExportImport(g, (x,))

    def test_inplace_flags(self):
        class InplaceFn(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                return x.add_(1)

            @staticmethod
            def backward(ctx, go):
                return go

        class RegularFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x.add(1)

            @staticmethod
            def backward(ctx, go):
                return go

        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = RegularFn.apply(x)
            y = InplaceFn.apply(y)
            y = InplaceFn.apply(y)
            y = RegularFn.apply(y)
            return y

        trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True)
        self.run_pass("dce", trace_graph)
        ops = list(trace_graph.nodes())
        for op in ops:
            self.assertTrue(op.hasAttribute("inplace"))
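        # fn applies RegularFn, InplaceFn, InplaceFn, RegularFn in that order,
        # so the recorded ops should carry these inplace flags.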
        inplace_flags = [False, True, True, False]
        for op, is_inplace in zip(ops, inplace_flags):
            self.assertEqual(op.i("inplace"), is_inplace)

    def test_inplace_check(self):
        class MyInplaceFn(Function):
            @staticmethod
            def forward(self, x):
                x.add_(1)
                self.mark_dirty(x)
                return x

            @staticmethod
            def backward(self, grad):
                return grad

        def fn(x):
            return MyInplaceFn.apply(x)

        x = torch.randn(5, 5)
        ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
        with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"):
            ge(x)

    def test_force_outplace_check_fill(self):
        def f(x):
            return torch.empty(x.shape).fill_(7)

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def test_force_outplace_check_zero(self):
        def f(x):
            return torch.empty(x.shape).zero_()

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def do_trace_size(self, requires_grad):
        def fn(x):
            return x.view(x.shape[1] * 2, x.size(0), 2)

        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
        y = torch.randn(4, 8, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def test_trace_size(self):
        self.do_trace_size(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_size_with_grad(self):
        self.do_trace_size(True)

    def test_trace_numel(self):
        def fn(x):
            return x.numel()

        x = torch.randn(2, 3, 4)
        y = torch.randn(4, 5, 6)

        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def do_trace_arange(self, requires_grad):
        def arange(x):
            return torch.arange(x.shape[0])

        def arange_scalar(x):
            return torch.arange(12)

        def arange_start_end(x):
            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)

        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
        y = torch.randn(8, 2, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_arange = torch.jit.trace(arange, x)
        self.assertEqual(traced_arange(y), arange(y))
        self.assertEqual(traced_arange(x), arange(x))

        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))

        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))

    def test_trace_arange(self):
        self.do_trace_arange(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_arange_with_grad(self):
        self.do_trace_arange(True)

    # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
    def test_trace_full_dynamic_shape(self):
        def full_with_shape_like(x):
            return torch.full(x.shape, 2.0)

        x = torch.randn(3, 4)
        ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
        y = torch.randn(2, 7)
        self.assertEqual(ge(y).shape, y.shape)
        self.assertEqual(ge(x).shape, x.shape)

    # Test that the trace of setitem doesn't store shapes as constants
    # Fix https://github.com/pytorch/pytorch/issues/43548
    def test_trace_slice_setitem_dynamic_shape(self):
        def slice_setitem(x, y):
            x[:, 2] = y + 1
            return x

        x = torch.randn(3, 4)
        traced = torch.jit.trace(slice_setitem, (x, x[:, 0]))
        x = torch.randn(10, 5)
        self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0]))

    # Suppression: we are intentionally slicing a tensor; we don't care that it
    # will be constantified
493*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
494*da0073e9SAndroid Build Coastguard Worker    def do_trace_slice(self, requires_grad):
495*da0073e9SAndroid Build Coastguard Worker        def slice(x):
496*da0073e9SAndroid Build Coastguard Worker            results = []
497*da0073e9SAndroid Build Coastguard Worker            for i in range(4):
498*da0073e9SAndroid Build Coastguard Worker                results.append(x[: x.size(0) - i, i : x.size(2), i:3])
499*da0073e9SAndroid Build Coastguard Worker            return tuple(results)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        def slice_select(x):
502*da0073e9SAndroid Build Coastguard Worker            results = []
503*da0073e9SAndroid Build Coastguard Worker            for i in range(4):
504*da0073e9SAndroid Build Coastguard Worker                results.append(x[:, i:, x.size(2) - 5])
505*da0073e9SAndroid Build Coastguard Worker            return tuple(results)
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 6, 7, requires_grad=requires_grad)
508*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(7, 8, 9, requires_grad=requires_grad)
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker        # Check that it behaves as expected
511*da0073e9SAndroid Build Coastguard Worker        traced_slice = torch.jit.trace(slice, x)
512*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_slice(y), slice(y))
513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_slice(x), slice(x))
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker        traced_slice_select = torch.jit.trace(slice_select, x)
516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_slice_select(y), slice_select(y))
517*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_slice_select(x), slice_select(x))
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker    def test_trace_slice(self):
520*da0073e9SAndroid Build Coastguard Worker        self.do_trace_slice(False)
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker    # test the different graph_executor path that happens when
523*da0073e9SAndroid Build Coastguard Worker    # gradients are required and sizes are involved
524*da0073e9SAndroid Build Coastguard Worker    def test_trace_slice_with_grad(self):
525*da0073e9SAndroid Build Coastguard Worker        self.do_trace_slice(True)
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker    def test_trace_casts(self):
528*da0073e9SAndroid Build Coastguard Worker        casts = [
529*da0073e9SAndroid Build Coastguard Worker            lambda x: x.byte(),
530*da0073e9SAndroid Build Coastguard Worker            lambda x: x.float(),
531*da0073e9SAndroid Build Coastguard Worker            lambda x: x.cpu(),
532*da0073e9SAndroid Build Coastguard Worker            lambda x: x.to(device="cpu"),
533*da0073e9SAndroid Build Coastguard Worker            lambda x: x.to(dtype=torch.int64),
534*da0073e9SAndroid Build Coastguard Worker            lambda x: x.to(device="cpu", dtype=torch.float),
535*da0073e9SAndroid Build Coastguard Worker            lambda x: x.to(x),
536*da0073e9SAndroid Build Coastguard Worker        ]
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker        def assertContainsCast(trace):
539*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
540*da0073e9SAndroid Build Coastguard Worker                sum(n.kind() == "aten::to" for n in trace.graph.nodes()), 1
541*da0073e9SAndroid Build Coastguard Worker            )
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker        for cast in casts:
544*da0073e9SAndroid Build Coastguard Worker            trace = torch.jit.trace(cast, torch.randn(2, 2))
545*da0073e9SAndroid Build Coastguard Worker            assertContainsCast(trace)
546*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 2)
547*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(trace(x), cast(x))
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker        def to_tensor(x, y):
550*da0073e9SAndroid Build Coastguard Worker            return x.to(y)
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        to_tensor_trace = torch.jit.trace(
553*da0073e9SAndroid Build Coastguard Worker            to_tensor, (torch.randn(2, 2), torch.randn(1, 8))
554*da0073e9SAndroid Build Coastguard Worker        )
555*da0073e9SAndroid Build Coastguard Worker        assertContainsCast(to_tensor_trace)
556*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(2, 2), torch.randn(1, 10)
557*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker    @skipIfCompiledWithoutNumpy
560*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
561*da0073e9SAndroid Build Coastguard Worker    def test_trace_warn(self):
562*da0073e9SAndroid Build Coastguard Worker        def fn(x):
563*da0073e9SAndroid Build Coastguard Worker            int(x)  # Warning 1.
564*da0073e9SAndroid Build Coastguard Worker            y = x * 1
565*da0073e9SAndroid Build Coastguard Worker            if y:  # Warning 2.
566*da0073e9SAndroid Build Coastguard Worker                pass
567*da0073e9SAndroid Build Coastguard Worker            q = [x, x * 4]
568*da0073e9SAndroid Build Coastguard Worker            z = q[y]
569*da0073e9SAndroid Build Coastguard Worker            float(z)  # Warning 3.
570*da0073e9SAndroid Build Coastguard Worker            z.tolist()  # Warning 4.
571*da0073e9SAndroid Build Coastguard Worker            z.numpy()  # Warning 5.
572*da0073e9SAndroid Build Coastguard Worker            for _ in torch.ones(4, 4):  # Warning 6.
573*da0073e9SAndroid Build Coastguard Worker                pass
574*da0073e9SAndroid Build Coastguard Worker            return z + 4
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
577*da0073e9SAndroid Build Coastguard Worker            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
578*da0073e9SAndroid Build Coastguard Worker        for warn in warns:
579*da0073e9SAndroid Build Coastguard Worker            self.assertIs(warn.category, torch.jit.TracerWarning)
580*da0073e9SAndroid Build Coastguard Worker        warns = [str(w.message) for w in warns]
581*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a Python integer", warns[0])
582*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a Python boolean", warns[1])
583*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a Python float", warns[2])
584*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a Python list", warns[3])
585*da0073e9SAndroid Build Coastguard Worker        self.assertIn("a NumPy array", warns[4])
586*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Iterating over", warns[5])
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker    def test_trace_tuple(self):
589*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
590*da0073e9SAndroid Build Coastguard Worker            return x, (x * y[1], x * y[0])
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
593*da0073e9SAndroid Build Coastguard Worker        traced_fn = torch.jit.trace(fn, (x, y))
594*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_fn(x, y), fn(x, y))
595*da0073e9SAndroid Build Coastguard Worker        # should be a tuple nested within another tuple
596*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next(
597*da0073e9SAndroid Build Coastguard Worker            "return"
598*da0073e9SAndroid Build Coastguard Worker        ).run(str(traced_fn.graph))
599*da0073e9SAndroid Build Coastguard Worker        self.assertExportImport(traced_fn.graph, (x, y))
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker    def test_trace_random(self):
602*da0073e9SAndroid Build Coastguard Worker        def f(mean, std):
603*da0073e9SAndroid Build Coastguard Worker            return torch.normal(mean, std)
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(
606*da0073e9SAndroid Build Coastguard Worker            f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False
607*da0073e9SAndroid Build Coastguard Worker        )
608*da0073e9SAndroid Build Coastguard Worker        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
609*da0073e9SAndroid Build Coastguard Worker        with torch.random.fork_rng(devices=[]):
610*da0073e9SAndroid Build Coastguard Worker            output = f(mean, std)
611*da0073e9SAndroid Build Coastguard Worker        traced_output = traced(mean, std)
612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, traced_output)
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker    def test_trace_tensor_factory(self):
615*da0073e9SAndroid Build Coastguard Worker        def run(**kwargs):
616*da0073e9SAndroid Build Coastguard Worker            inputs_require_grads = kwargs.pop("inputs_require_grads", True)
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker            def fn(x):
619*da0073e9SAndroid Build Coastguard Worker                return x + torch.ones(2, 3, **kwargs)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker            input_kwargs = kwargs.copy()
622*da0073e9SAndroid Build Coastguard Worker            if "out" in input_kwargs:
623*da0073e9SAndroid Build Coastguard Worker                del input_kwargs["out"]
624*da0073e9SAndroid Build Coastguard Worker            input = torch.ones(2, 3, **input_kwargs)
625*da0073e9SAndroid Build Coastguard Worker            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
626*da0073e9SAndroid Build Coastguard Worker            # check we recorded 'ones' and did not just record a constant
627*da0073e9SAndroid Build Coastguard Worker            tfn = torch.jit.trace(fn, input)
628*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("ones" in str(tfn.graph))
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker        run()
631*da0073e9SAndroid Build Coastguard Worker        run(dtype=torch.int, inputs_require_grads=False)
632*da0073e9SAndroid Build Coastguard Worker        run(out=torch.tensor([]))
633*da0073e9SAndroid Build Coastguard Worker        if RUN_CUDA:
634*da0073e9SAndroid Build Coastguard Worker            run(device="cuda:0")
635*da0073e9SAndroid Build Coastguard Worker        if RUN_CUDA_MULTI_GPU:
636*da0073e9SAndroid Build Coastguard Worker            run(device="cuda:1")
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker    def test_trace_indexed_assignment(self):
639*da0073e9SAndroid Build Coastguard Worker        def stuff(x, y):
640*da0073e9SAndroid Build Coastguard Worker            x = x.clone()
641*da0073e9SAndroid Build Coastguard Worker            x[0] = y
642*da0073e9SAndroid Build Coastguard Worker            return x
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Worker        example = torch.rand(3, 4)
645*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(stuff, (example, example[0] + 1))
646*da0073e9SAndroid Build Coastguard Worker
647*da0073e9SAndroid Build Coastguard Worker    # TODO: implement
648*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
649*da0073e9SAndroid Build Coastguard Worker    def test_output_unflatten(self):
650*da0073e9SAndroid Build Coastguard Worker        """Check that outputs of traced functions retain the original structure and nesting"""
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker        def fn(x):
653*da0073e9SAndroid Build Coastguard Worker            return (
654*da0073e9SAndroid Build Coastguard Worker                x * 2,
655*da0073e9SAndroid Build Coastguard Worker                (
656*da0073e9SAndroid Build Coastguard Worker                    x**2,
657*da0073e9SAndroid Build Coastguard Worker                    x + 4,
658*da0073e9SAndroid Build Coastguard Worker                    (x + 2,),
659*da0073e9SAndroid Build Coastguard Worker                ),
660*da0073e9SAndroid Build Coastguard Worker                x * 4,
661*da0073e9SAndroid Build Coastguard Worker            )
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(fn, (torch.randn(2, 2),))
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker    def test_input_flatten(self):
666*da0073e9SAndroid Build Coastguard Worker        """Check that inputs to traced functions are flattened"""
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker        def fn(x, t):
669*da0073e9SAndroid Build Coastguard Worker            y, z = t
670*da0073e9SAndroid Build Coastguard Worker            return x * y * z
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
673*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(fn, inputs)
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_empty(self):
676*da0073e9SAndroid Build Coastguard Worker        def test(d):
677*da0073e9SAndroid Build Coastguard Worker            pass
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
680*da0073e9SAndroid Build Coastguard Worker            self.checkTrace(test, {})
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_remembers_keys(self):
683*da0073e9SAndroid Build Coastguard Worker        """Check that the trace remembers which keys were in a dict input"""
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
686*da0073e9SAndroid Build Coastguard Worker            def forward(self, dict_input):
687*da0073e9SAndroid Build Coastguard Worker                return dict_input["x"]
688*da0073e9SAndroid Build Coastguard Worker
689*da0073e9SAndroid Build Coastguard Worker        input_1 = {"x": torch.tensor(1)}
690*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
691*da0073e9SAndroid Build Coastguard Worker        m_traced = torch.jit.trace(m, (input_1,))
692*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_1), torch.tensor(1))
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker        # should work to change the values and not the keys
695*da0073e9SAndroid Build Coastguard Worker        input_same_key_different_value = {"x": torch.tensor(2)}
696*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))
697*da0073e9SAndroid Build Coastguard Worker
698*da0073e9SAndroid Build Coastguard Worker        # error to use something that doesn't have `x`
699*da0073e9SAndroid Build Coastguard Worker        input_different_key = {"y": torch.tensor(3)}
700*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
701*da0073e9SAndroid Build Coastguard Worker            m_traced(input_different_key)
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker        # it's okay to have additional elements in the dictionary, so long as 'x' is there
704*da0073e9SAndroid Build Coastguard Worker        input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)}
705*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_additional_key), torch.tensor(4))
706*da0073e9SAndroid Build Coastguard Worker
707*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_insertion_order(self):
708*da0073e9SAndroid Build Coastguard Worker        """Check that dictionary access doesn't care about insertion order"""
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
711*da0073e9SAndroid Build Coastguard Worker            def forward(self, dict_input):
712*da0073e9SAndroid Build Coastguard Worker                return dict_input["x"], dict_input["y"]
713*da0073e9SAndroid Build Coastguard Worker
714*da0073e9SAndroid Build Coastguard Worker        input_x_then_y = {}
715*da0073e9SAndroid Build Coastguard Worker        input_x_then_y["x"] = torch.tensor(1)
716*da0073e9SAndroid Build Coastguard Worker        input_x_then_y["y"] = torch.tensor(2)
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
719*da0073e9SAndroid Build Coastguard Worker        m_traced = torch.jit.trace(m, (input_x_then_y,))
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2)))
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker        input_y_then_x = {}
724*da0073e9SAndroid Build Coastguard Worker        input_y_then_x["y"] = torch.tensor(4)
725*da0073e9SAndroid Build Coastguard Worker        input_y_then_x["x"] = torch.tensor(3)
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4)))
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_recursive(self):
730*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
731*da0073e9SAndroid Build Coastguard Worker            def forward(self, dict_input):
732*da0073e9SAndroid Build Coastguard Worker                return dict_input["x"][1]
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker        input_1 = {"x": {1: torch.tensor(1)}}
735*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
736*da0073e9SAndroid Build Coastguard Worker        m_traced = torch.jit.trace(m, (input_1,))
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker        input_2 = {"x": {1: torch.tensor(2)}}
739*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m_traced(input_2), torch.tensor(2))
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_checkTrace_mut(self):
742*da0073e9SAndroid Build Coastguard Worker        def test(d):
743*da0073e9SAndroid Build Coastguard Worker            d["x"].tanh_()
744*da0073e9SAndroid Build Coastguard Worker            return d["x"]
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker        inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)}
747*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, (inputs,), inputs_require_grads=False)
748*da0073e9SAndroid Build Coastguard Worker
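    # Dict values of different dtypes (int32 vs. float32) should still be
    # traceable; the dict value type presumably unifies to a plain Tensor.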
749*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_unify(self):
750*da0073e9SAndroid Build Coastguard Worker        def test(d):
751*da0073e9SAndroid Build Coastguard Worker            return d["int"], d["float"]
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker        inputs = {
754*da0073e9SAndroid Build Coastguard Worker            "int": torch.ones((2, 2), dtype=torch.int32),
755*da0073e9SAndroid Build Coastguard Worker            "float": torch.ones((2, 2), dtype=torch.float32),
756*da0073e9SAndroid Build Coastguard Worker        }
757*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, (inputs,), inputs_require_grads=False)
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker    def test_input_tuple_of_dicts(self):
760*da0073e9SAndroid Build Coastguard Worker        def test(t):
761*da0073e9SAndroid Build Coastguard Worker            d = t[0]
762*da0073e9SAndroid Build Coastguard Worker            return d["x"]["y"]
763*da0073e9SAndroid Build Coastguard Worker
764*da0073e9SAndroid Build Coastguard Worker        inputs = {"x": {"y": torch.rand(2, 3)}}
765*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_of_dicts(self):
768*da0073e9SAndroid Build Coastguard Worker        def test(d):
769*da0073e9SAndroid Build Coastguard Worker            return d["x"]["y"]
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker        nested_input = {"y": torch.rand(2, 3)}
772*da0073e9SAndroid Build Coastguard Worker        unified_nested = {"y": torch.rand(3, 2)}
773*da0073e9SAndroid Build Coastguard Worker        inputs = {"x": nested_input, "force_unify": unified_nested}
774*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, (inputs,), allow_unused=True)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_of_lists(self):
777*da0073e9SAndroid Build Coastguard Worker        def test(d):
778*da0073e9SAndroid Build Coastguard Worker            return d["x"][0]
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker        inputs = {"x": [torch.rand(3, 2)]}
781*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, (inputs,))
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker    def test_input_list_toplevel_flatten(self):
784*da0073e9SAndroid Build Coastguard Worker        def test(t1, t2):
785*da0073e9SAndroid Build Coastguard Worker            return torch.add(t1, t2)
786*da0073e9SAndroid Build Coastguard Worker
787*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
788*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, inputs)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker    def test_input_list_toplevel_flatten_direct(self):
791*da0073e9SAndroid Build Coastguard Worker        class Test(torch.nn.Module):
792*da0073e9SAndroid Build Coastguard Worker            def forward(self, t1, t2):
793*da0073e9SAndroid Build Coastguard Worker                return torch.add(t1, t2)
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
796*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(Test(), inputs)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker    def test_input_list_of_tuples(self):
799*da0073e9SAndroid Build Coastguard Worker        def test(l):
800*da0073e9SAndroid Build Coastguard Worker            return l[0][0]
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker        inputs = [(torch.ones(2, 2),)]
803*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(test, (inputs,))
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker    def test_input_dict_empty_list(self):
806*da0073e9SAndroid Build Coastguard Worker        def test(d):
807*da0073e9SAndroid Build Coastguard Worker            pass
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker        inputs = {1: []}
810*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "List trace"):
811*da0073e9SAndroid Build Coastguard Worker            self.checkTrace(test, (inputs,))
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker    def test_input_list_mixed_type(self):
814*da0073e9SAndroid Build Coastguard Worker        def test(d):
815*da0073e9SAndroid Build Coastguard Worker            pass
816*da0073e9SAndroid Build Coastguard Worker
817*da0073e9SAndroid Build Coastguard Worker        inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
818*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "consistent"):
819*da0073e9SAndroid Build Coastguard Worker            self.checkTrace(test, (inputs,))
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker    def test_conv(self):
822*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(20, 16, 50, 40)
823*da0073e9SAndroid Build Coastguard Worker        g, outputs, inputs = torch.jit._get_trace_graph(
824*da0073e9SAndroid Build Coastguard Worker            nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True
825*da0073e9SAndroid Build Coastguard Worker        )
826*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, m(*inputs))
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker    def test_max_pool(self):
830*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(20, 16, 10, 10)
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        def max_pool2d(x):
833*da0073e9SAndroid Build Coastguard Worker            return F.max_pool2d(x, 2) + 2
834*da0073e9SAndroid Build Coastguard Worker
835*da0073e9SAndroid Build Coastguard Worker        trace = torch.jit.trace(max_pool2d, (x,))
836*da0073e9SAndroid Build Coastguard Worker        graph = trace.graph_for(x)
837*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::max_pool2d(").run(graph)
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(max_pool2d(x), trace(x))
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker    def test_nested_inplace(self):
841*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
842*da0073e9SAndroid Build Coastguard Worker        g, outputs, inputs = torch.jit._get_trace_graph(
843*da0073e9SAndroid Build Coastguard Worker            lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True
844*da0073e9SAndroid Build Coastguard Worker        )
845*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, m(*inputs))
847*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("threshold_").run(str(g))
848*da0073e9SAndroid Build Coastguard Worker        self.assertExportImport(g, (x,))
849*da0073e9SAndroid Build Coastguard Worker
850*da0073e9SAndroid Build Coastguard Worker    def test_repeated_input(self):
851*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
852*da0073e9SAndroid Build Coastguard Worker            return a + b
853*da0073e9SAndroid Build Coastguard Worker
854*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
855*da0073e9SAndroid Build Coastguard Worker        inputs = set(ge.graph.inputs())
856*da0073e9SAndroid Build Coastguard Worker        # three inputs instead of two because the export/import in checkTrace
857*da0073e9SAndroid Build Coastguard Worker        # adds a `self` module argument
858*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(len(inputs) == 3)
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker    def test_repeated_output(self):
861*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
862*da0073e9SAndroid Build Coastguard Worker            z = a + b
863*da0073e9SAndroid Build Coastguard Worker            return z, z
864*da0073e9SAndroid Build Coastguard Worker
865*da0073e9SAndroid Build Coastguard Worker        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
866*da0073e9SAndroid Build Coastguard Worker        tuple_output = list(ge.graph.outputs())[0]
867*da0073e9SAndroid Build Coastguard Worker        tuple_inputs = list(tuple_output.node().inputs())
868*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(tuple_inputs[0] == tuple_inputs[1])
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker    def test_inplace_copy(self):
871*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, requires_grad=True)
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker        def f(x):
874*da0073e9SAndroid Build Coastguard Worker            out = torch.zeros(x.size())
875*da0073e9SAndroid Build Coastguard Worker            out.copy_(x)
876*da0073e9SAndroid Build Coastguard Worker            return out
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker        g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True)
879*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", g)
880*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
881*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, m(*inputs))
882*da0073e9SAndroid Build Coastguard Worker        self.assertExportImport(g, (x,))
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker    def test_inplace_copy_force_outplace(self):
885*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4, requires_grad=True)
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker        def f(x):
888*da0073e9SAndroid Build Coastguard Worker            out = torch.zeros(x.size())
889*da0073e9SAndroid Build Coastguard Worker            out.copy_(x)
890*da0073e9SAndroid Build Coastguard Worker            return out
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        g, outputs, inputs = torch.jit._get_trace_graph(
893*da0073e9SAndroid Build Coastguard Worker            f, (x,), return_inputs=True, _force_outplace=True
894*da0073e9SAndroid Build Coastguard Worker        )
895*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", g)
896*da0073e9SAndroid Build Coastguard Worker        m = self.createFunctionFromGraph(g)
897*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputs, m(*inputs))
898*da0073e9SAndroid Build Coastguard Worker        self.assertExportImport(g, (x,))
899*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("expand_as").run(str(g))
900*da0073e9SAndroid Build Coastguard Worker
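    # A parameter registered under two names (self.a and self.b) should appear
    # only once in the traced graph, giving two inputs: x and the shared parameter.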
901*da0073e9SAndroid Build Coastguard Worker    def test_shared_param(self):
902*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
903*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
904*da0073e9SAndroid Build Coastguard Worker                super().__init__()
905*da0073e9SAndroid Build Coastguard Worker                self.b = self.a = nn.Parameter(torch.randn(2, 2))
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
908*da0073e9SAndroid Build Coastguard Worker                return x * self.a + self.b
909*da0073e9SAndroid Build Coastguard Worker
910*da0073e9SAndroid Build Coastguard Worker        m = MyModule()
911*da0073e9SAndroid Build Coastguard Worker        g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),))
912*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", g)
913*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(g.inputs())), 2)
914*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("mul").check("add").run(str(g))
915*da0073e9SAndroid Build Coastguard Worker
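    # Shared driver for the graph executor ("ge") tests: traces a few small
    # arithmetic functions (optionally on CUDA) and checks outputs and gradients
    # via checkTrace under the given execution mode.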
916*da0073e9SAndroid Build Coastguard Worker    def run_ge_tests(self, optimize, use_cuda):
917*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
918*da0073e9SAndroid Build Coastguard Worker            with torch.jit.optimized_execution(optimize):
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker                def rand(*args):
921*da0073e9SAndroid Build Coastguard Worker                    t = torch.rand(*args).float()
922*da0073e9SAndroid Build Coastguard Worker                    if use_cuda:
923*da0073e9SAndroid Build Coastguard Worker                        t = t.cuda()
924*da0073e9SAndroid Build Coastguard Worker                    return t
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(
927*da0073e9SAndroid Build Coastguard Worker                    lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)]
928*da0073e9SAndroid Build Coastguard Worker                )
929*da0073e9SAndroid Build Coastguard Worker                # trivial identity
930*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)])
931*da0073e9SAndroid Build Coastguard Worker
932*da0073e9SAndroid Build Coastguard Worker                def foo(a):
933*da0073e9SAndroid Build Coastguard Worker                    t = a * a
934*da0073e9SAndroid Build Coastguard Worker                    return t * t, 4 * t
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(foo, [rand(1)])
937*da0073e9SAndroid Build Coastguard Worker                # unused input
938*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(
939*da0073e9SAndroid Build Coastguard Worker                    lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True
940*da0073e9SAndroid Build Coastguard Worker                )
941*da0073e9SAndroid Build Coastguard Worker                # test outputs that do not get used in grad
942*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(foo, [rand(1)], drop=1)
943*da0073e9SAndroid Build Coastguard Worker                # test autograd fallback
944*da0073e9SAndroid Build Coastguard Worker                self.checkTrace(
945*da0073e9SAndroid Build Coastguard Worker                    lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)]
946*da0073e9SAndroid Build Coastguard Worker                )
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker    def test_ge_unoptimized(self):
949*da0073e9SAndroid Build Coastguard Worker        self.run_ge_tests(False, False)
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
952*da0073e9SAndroid Build Coastguard Worker    @enable_cpu_fuser
953*da0073e9SAndroid Build Coastguard Worker    def test_ge_optimized(self):
954*da0073e9SAndroid Build Coastguard Worker        with enable_profiling_mode_for_profiling_tests():
955*da0073e9SAndroid Build Coastguard Worker            self.run_ge_tests(True, False)
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
958*da0073e9SAndroid Build Coastguard Worker    def test_ge_cuda(self):
959*da0073e9SAndroid Build Coastguard Worker        self.run_ge_tests(True, True)
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker    # A more manual test of the graph executor that can be used as a scratchpad.
962*da0073e9SAndroid Build Coastguard Worker    def test_ge(self):
963*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
964*da0073e9SAndroid Build Coastguard Worker            return a * b / (a - b) + b
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker        V = Variable
967*da0073e9SAndroid Build Coastguard Worker        a, b = V(torch.rand(1)), V(torch.rand(1))
968*da0073e9SAndroid Build Coastguard Worker        ge = torch.jit.trace(foo, (a, b))
969*da0073e9SAndroid Build Coastguard Worker        a, b = V(torch.rand(1), requires_grad=True), V(
970*da0073e9SAndroid Build Coastguard Worker            torch.rand(1), requires_grad=True
971*da0073e9SAndroid Build Coastguard Worker        )
972*da0073e9SAndroid Build Coastguard Worker        (r,) = ge(a, b)
973*da0073e9SAndroid Build Coastguard Worker        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
974*da0073e9SAndroid Build Coastguard Worker
975*da0073e9SAndroid Build Coastguard Worker        l2 = da * db + db * db
976*da0073e9SAndroid Build Coastguard Worker        g2result = torch.autograd.grad(l2, [da, db])
977*da0073e9SAndroid Build Coastguard Worker
978*da0073e9SAndroid Build Coastguard Worker        r = foo(a, b)
979*da0073e9SAndroid Build Coastguard Worker        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
980*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(da, da2)
981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(db, db2)
982*da0073e9SAndroid Build Coastguard Worker        l3 = da2 * db2 + db2 * db2
983*da0073e9SAndroid Build Coastguard Worker        g2result2 = torch.autograd.grad(l3, [da2, db2])
984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g2result, g2result2)
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker    def test_trace_annotation(self):
987*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(1))
988*da0073e9SAndroid Build Coastguard Worker        def foo(a):
989*da0073e9SAndroid Build Coastguard Worker            return a + a + a
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(x), x + x + x)
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
995*da0073e9SAndroid Build Coastguard Worker    # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision.
996*da0073e9SAndroid Build Coastguard Worker    # We want float tensors computed at full precision so the CPU and CUDA results compared below agree.
997*da0073e9SAndroid Build Coastguard Worker    @with_tf32_off
998*da0073e9SAndroid Build Coastguard Worker    def test_traced_module_cuda(self):
999*da0073e9SAndroid Build Coastguard Worker        class Model(nn.Module):
1000*da0073e9SAndroid Build Coastguard Worker            def __init__(self, num_features, num_layers):
1001*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1002*da0073e9SAndroid Build Coastguard Worker                self.num_layers = num_layers
1003*da0073e9SAndroid Build Coastguard Worker                layers = [
1004*da0073e9SAndroid Build Coastguard Worker                    [nn.Linear(num_features, num_features), nn.Sigmoid()]
1005*da0073e9SAndroid Build Coastguard Worker                    for _ in range(num_layers)
1006*da0073e9SAndroid Build Coastguard Worker                ]
1007*da0073e9SAndroid Build Coastguard Worker                self.submodule = nn.Sequential(*chain(*layers))
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1010*da0073e9SAndroid Build Coastguard Worker                for i in range(self.num_layers):
1011*da0073e9SAndroid Build Coastguard Worker                    x = self.submodule[i](x) + x
1012*da0073e9SAndroid Build Coastguard Worker                return x
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker        model = Model(5, 3)
1015*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 5)
1016*da0073e9SAndroid Build Coastguard Worker        traced_model = torch.jit.trace(model, x)
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # We're missing some attributes these modules had initially. Make sure we can
1019*da0073e9SAndroid Build Coastguard Worker        # still get the __repr__()
1020*da0073e9SAndroid Build Coastguard Worker        model.__repr__()
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker        # XXX: indexing sequentials is broken
1023*da0073e9SAndroid Build Coastguard Worker        linear_submodule = next(iter(traced_model.submodule._modules.values()))
1024*da0073e9SAndroid Build Coastguard Worker
1025*da0073e9SAndroid Build Coastguard Worker        # Accessing attributes that aren't parameters should raise an AttributeError
1026*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AttributeError):
1027*da0073e9SAndroid Build Coastguard Worker            linear_submodule.in_features
1028*da0073e9SAndroid Build Coastguard Worker        linear_submodule.weight
1029*da0073e9SAndroid Build Coastguard Worker        linear_submodule.weight = nn.Parameter(
1030*da0073e9SAndroid Build Coastguard Worker            torch.randn(linear_submodule.weight.shape)
1031*da0073e9SAndroid Build Coastguard Worker        )
1032*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1033*da0073e9SAndroid Build Coastguard Worker            del linear_submodule.weight
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker        # Submodules of a traced module can't be called directly
1036*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
1037*da0073e9SAndroid Build Coastguard Worker            linear_submodule(x)
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker        # Type casts
1040*da0073e9SAndroid Build Coastguard Worker        linear_submodule.cuda()
1041*da0073e9SAndroid Build Coastguard Worker        traced_model.float().cuda()
1042*da0073e9SAndroid Build Coastguard Worker        cuda_out = traced_model(x.float().cuda())
1043*da0073e9SAndroid Build Coastguard Worker        traced_model.cpu()
1044*da0073e9SAndroid Build Coastguard Worker        cpu_out = traced_model(x.float())
1045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_out, cuda_out)
1046*da0073e9SAndroid Build Coastguard Worker        traced_model.to("cuda")
1047*da0073e9SAndroid Build Coastguard Worker        cuda_out = traced_model(x.float().cuda())
1048*da0073e9SAndroid Build Coastguard Worker        traced_model.to("cpu")
1049*da0073e9SAndroid Build Coastguard Worker        cpu_out = traced_model(x.float())
1050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_out, cuda_out)
1051*da0073e9SAndroid Build Coastguard Worker        traced_model.to(torch.get_default_dtype())
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker        # state_dict + load_state_dict
1054*da0073e9SAndroid Build Coastguard Worker        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
1055*da0073e9SAndroid Build Coastguard Worker        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
1056*da0073e9SAndroid Build Coastguard Worker        out = traced_model(x)
1057*da0073e9SAndroid Build Coastguard Worker        traced_model.load_state_dict(new_state)
1058*da0073e9SAndroid Build Coastguard Worker        out_ones = traced_model(x)
1059*da0073e9SAndroid Build Coastguard Worker        traced_model.load_state_dict(state)
1060*da0073e9SAndroid Build Coastguard Worker        out_state = traced_model(x)
1061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_state)
1062*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(out, out_ones)
1063*da0073e9SAndroid Build Coastguard Worker
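    # Casting with x.type(self.dtype) on a tensor that is already on the target
    # device should not bake an explicit device into the traced code.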
1064*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "uses cuda")
1065*da0073e9SAndroid Build Coastguard Worker    def test_type_same_device(self):
1066*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
1067*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1068*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1069*da0073e9SAndroid Build Coastguard Worker                self.dtype = torch.float16
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker            def forward(self, x=None):
1072*da0073e9SAndroid Build Coastguard Worker                h = x.type(self.dtype)
1073*da0073e9SAndroid Build Coastguard Worker                return h
1074*da0073e9SAndroid Build Coastguard Worker
1075*da0073e9SAndroid Build Coastguard Worker        a = Model()
1076*da0073e9SAndroid Build Coastguard Worker        b = torch.jit.trace(
1077*da0073e9SAndroid Build Coastguard Worker            a, example_inputs=(torch.ones([1], device=torch.device("cuda")),)
1078*da0073e9SAndroid Build Coastguard Worker        )
1079*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("device").run(b.code)
1080*da0073e9SAndroid Build Coastguard Worker
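    # An export/import round trip of a traced function should preserve both its
    # outputs and its gradients.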
1081*da0073e9SAndroid Build Coastguard Worker    def test_export_no_reorder(self):
1082*da0073e9SAndroid Build Coastguard Worker        def func(a, b):
1083*da0073e9SAndroid Build Coastguard Worker            return a * b / (a - 2 * b) + b
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker        recording_inputs = [
1086*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
1087*da0073e9SAndroid Build Coastguard Worker                [0.55619788169860839844], dtype=torch.float32, requires_grad=True
1088*da0073e9SAndroid Build Coastguard Worker            ),
1089*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
1090*da0073e9SAndroid Build Coastguard Worker                [0.25947844982147216797], dtype=torch.float32, requires_grad=True
1091*da0073e9SAndroid Build Coastguard Worker            ),
1092*da0073e9SAndroid Build Coastguard Worker        ]
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker        ge1 = torch.jit.trace(func, recording_inputs)
1095*da0073e9SAndroid Build Coastguard Worker        ge2 = self.getExportImportCopy(ge1)
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker        outputs_ge1 = ge1(*recording_inputs)
1098*da0073e9SAndroid Build Coastguard Worker        outputs_ge2 = ge2(*recording_inputs)
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker        grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
1101*da0073e9SAndroid Build Coastguard Worker        grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
1102*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(outputs_ge1 == outputs_ge2)
1103*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(grad_ge1 == grad_ge2)
1104*da0073e9SAndroid Build Coastguard Worker
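    # Tracing through a custom autograd.Function should work, and the traced
    # function should be reusable with inputs of a different shape.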
1105*da0073e9SAndroid Build Coastguard Worker    def test_python_function(self):
1106*da0073e9SAndroid Build Coastguard Worker        class MyFn(Function):
1107*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1108*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
1109*da0073e9SAndroid Build Coastguard Worker                return x + 1
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1112*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
1113*da0073e9SAndroid Build Coastguard Worker                return grad_output
1114*da0073e9SAndroid Build Coastguard Worker
1115*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(2))
1116*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1117*da0073e9SAndroid Build Coastguard Worker            return MyFn.apply(x + 2) + 3
1118*da0073e9SAndroid Build Coastguard Worker
1119*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0, 2.0, 3.0])
1120*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
1121*da0073e9SAndroid Build Coastguard Worker        fn(x)
1122*da0073e9SAndroid Build Coastguard Worker        fn(y)
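        # Sanity check: tracing through MyFn should preserve the arithmetic
        # (MyFn adds 1, so fn(t) == t + 6).
        self.assertEqual(fn(y), y + 6)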
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    def test_python_function_tup(self):
1125*da0073e9SAndroid Build Coastguard Worker        class MyFn(Function):
1126*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1127*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
1128*da0073e9SAndroid Build Coastguard Worker                return x + 1, x - 1
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1131*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
1132*da0073e9SAndroid Build Coastguard Worker                return grad_output, grad_output
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(2))
1135*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1136*da0073e9SAndroid Build Coastguard Worker            a, b = MyFn.apply(x + 2)
1137*da0073e9SAndroid Build Coastguard Worker            return a + b + 3
1138*da0073e9SAndroid Build Coastguard Worker
1139*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0, 2.0, 3.0])
1140*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
1141*da0073e9SAndroid Build Coastguard Worker        fn(x)
1142*da0073e9SAndroid Build Coastguard Worker        fn(y)
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker    def test_trace_detach(self):
1145*da0073e9SAndroid Build Coastguard Worker        def foo(x, w):
1146*da0073e9SAndroid Build Coastguard Worker            return torch.matmul(x, w).detach()
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("matmul").check("detach").run(str(traced.graph))
1151*da0073e9SAndroid Build Coastguard Worker        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
1152*da0073e9SAndroid Build Coastguard Worker        traced_result = traced(x, w)
1153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(x, w), traced_result)
1154*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(traced_result.requires_grad)
1155*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(traced_result.grad_fn)
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker    def test_trace_detach_redispatch(self):
1158*da0073e9SAndroid Build Coastguard Worker        def foo(x, w):
1159*da0073e9SAndroid Build Coastguard Worker            y = torch.matmul(x, w)
1160*da0073e9SAndroid Build Coastguard Worker            assert y.requires_grad
1161*da0073e9SAndroid Build Coastguard Worker            y = y.detach()
1162*da0073e9SAndroid Build Coastguard Worker            # Make sure trace kernel redispatches to the right lower kernel.
1163*da0073e9SAndroid Build Coastguard Worker            assert not y.requires_grad
1164*da0073e9SAndroid Build Coastguard Worker            return y
1165*da0073e9SAndroid Build Coastguard Worker
1166*da0073e9SAndroid Build Coastguard Worker        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
1167*da0073e9SAndroid Build Coastguard Worker        # With `check_trace=True` the check would run under `torch.no_grad()` and break the asserts above.
1168*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(foo, (x, w), check_trace=False)
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker    def test_trace_detach_inplace(self):
1171*da0073e9SAndroid Build Coastguard Worker        def foo(x, w):
1172*da0073e9SAndroid Build Coastguard Worker            y = torch.matmul(x, w)
1173*da0073e9SAndroid Build Coastguard Worker            y.detach_()
1174*da0073e9SAndroid Build Coastguard Worker            return y
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("matmul").check("detach(").run(str(traced.graph))
1179*da0073e9SAndroid Build Coastguard Worker        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
1180*da0073e9SAndroid Build Coastguard Worker        traced_result = traced(x, w)
1181*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(x, w), traced_result)
1182*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(traced_result.requires_grad)
1183*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(traced_result.grad_fn)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker    def test_trace_detach_inplace_redispatch(self):
1186*da0073e9SAndroid Build Coastguard Worker        def foo(x, w):
1187*da0073e9SAndroid Build Coastguard Worker            y = torch.matmul(x, w)
1188*da0073e9SAndroid Build Coastguard Worker            assert y.requires_grad
1189*da0073e9SAndroid Build Coastguard Worker            y.detach_()
1190*da0073e9SAndroid Build Coastguard Worker            # Make sure trace kernel redispatches to the right lower kernel.
1191*da0073e9SAndroid Build Coastguard Worker            assert not y.requires_grad
1192*da0073e9SAndroid Build Coastguard Worker            return y
1193*da0073e9SAndroid Build Coastguard Worker
1194*da0073e9SAndroid Build Coastguard Worker        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
1195*da0073e9SAndroid Build Coastguard Worker        # With `check_trace=True` the check would run under `torch.no_grad()` and break the asserts above.
1196*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(foo, (x, w), check_trace=False)
1197*da0073e9SAndroid Build Coastguard Worker
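    # A slice that spans the whole dimension at trace time should still match
    # eager execution when the traced function is run on a different size.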
1198*da0073e9SAndroid Build Coastguard Worker    def test_trace_slice_full_dim(self):
1199*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1200*da0073e9SAndroid Build Coastguard Worker            return x[0:5, 0] + 1.0
1201*da0073e9SAndroid Build Coastguard Worker
1202*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
1203*da0073e9SAndroid Build Coastguard Worker        test_x = torch.rand(6, 3)
1204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(test_x), traced(test_x))
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker    def test_trace_dict_input(self):
1207*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.nn.Module):
1208*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1209*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1210*da0073e9SAndroid Build Coastguard Worker                self.foo = Foo()
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
1213*da0073e9SAndroid Build Coastguard Worker                return self.foo({"a": a, "b": b})["a"]
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1216*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1217*da0073e9SAndroid Build Coastguard Worker                return {"a": x["a"] * x["b"]}
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker        x = (torch.rand(3), torch.rand(3))
1220*da0073e9SAndroid Build Coastguard Worker        model = Bar()
1221*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(model, x)
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker    def test_trace_dict_output(self):
1224*da0073e9SAndroid Build Coastguard Worker        class TraceDictStrTensor(torch.nn.Module):
1225*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
1226*da0073e9SAndroid Build Coastguard Worker                return {"a": a, "b": b}
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker        class TraceDictTensorTensor(torch.nn.Module):
1229*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
1230*da0073e9SAndroid Build Coastguard Worker                return {a: b, b: a}
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker        x = (torch.rand(3), torch.rand(3))
1233*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"):
1234*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(TraceDictStrTensor(), x)
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker        traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False)
1237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]})
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker        traced_dict_tensor_mod = torch.jit.trace(
1240*da0073e9SAndroid Build Coastguard Worker            TraceDictTensorTensor(), x, strict=False
1241*da0073e9SAndroid Build Coastguard Worker        )
1242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]})
1243*da0073e9SAndroid Build Coastguard Worker
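    # In strict mode the tracer warns that a mutable list output may cause the
    # trace to be incorrect; strict=False permits returning the list.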
1244*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_tensor_list_output(self):
1245*da0073e9SAndroid Build Coastguard Worker        def f():
1246*da0073e9SAndroid Build Coastguard Worker            return [torch.zeros(1), torch.zeros(5)]
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1249*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning, "cause the trace to be incorrect"
1250*da0073e9SAndroid Build Coastguard Worker        ):
1251*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(f, [])
1252*da0073e9SAndroid Build Coastguard Worker        traced_non_strict_f = torch.jit.trace(f, [], strict=False)
1253*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_non_strict_f(), f())
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_number_list_output(self):
1256*da0073e9SAndroid Build Coastguard Worker        def f():
1257*da0073e9SAndroid Build Coastguard Worker            return [1, 5]
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1260*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Only tensors.+can be output from traced functions"
1261*da0073e9SAndroid Build Coastguard Worker        ):
1262*da0073e9SAndroid Build Coastguard Worker            traced_f = torch.jit.trace(f, [])
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_nested_tensor_list_output(self):
1265*da0073e9SAndroid Build Coastguard Worker        def f():
1266*da0073e9SAndroid Build Coastguard Worker            return [[torch.zeros(1)], [torch.zeros(5)]]
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1269*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Only tensors.+can be output from traced functions"
1270*da0073e9SAndroid Build Coastguard Worker        ):
1271*da0073e9SAndroid Build Coastguard Worker            traced_f = torch.jit.trace(f, [])
1272*da0073e9SAndroid Build Coastguard Worker
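    # Nested tensors constructed inside a scripted helper should trace and give
    # the same results for a different set of split lengths.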
1273*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_nested_strided_tensor_output(self):
1274*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1275*da0073e9SAndroid Build Coastguard Worker        def nt_construct(values, kv_lengths):
1276*da0073e9SAndroid Build Coastguard Worker            kv_lengths_list: List[int] = kv_lengths.tolist()
1277*da0073e9SAndroid Build Coastguard Worker            return torch._nested_tensor_from_tensor_list(
1278*da0073e9SAndroid Build Coastguard Worker                list(values.split(kv_lengths_list, dim=0)), None, None, None, None
1279*da0073e9SAndroid Build Coastguard Worker            )
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        def f(x, offsets):
1282*da0073e9SAndroid Build Coastguard Worker            kv_lengths = offsets[1:] - offsets[:-1]
1283*da0073e9SAndroid Build Coastguard Worker            return nt_construct(x, kv_lengths).cos()
1284*da0073e9SAndroid Build Coastguard Worker
1285*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(5, 4)
1286*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 2, 5])
1287*da0073e9SAndroid Build Coastguard Worker        ref = f(x, offsets)
1288*da0073e9SAndroid Build Coastguard Worker        f_t = torch.jit.trace(f, (x, offsets))
1289*da0073e9SAndroid Build Coastguard Worker        res = f_t(x, offsets)
1290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1291*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand((8, 4))
1292*da0073e9SAndroid Build Coastguard Worker        offsets2 = torch.tensor([0, 2, 4, 8])
1293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f(x2, offsets2), f_t(x2, offsets2))
1294*da0073e9SAndroid Build Coastguard Worker
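    # Variable is a legacy alias for Tensor; instantiating Variables inside a
    # traced function should still produce a correct trace.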
1295*da0073e9SAndroid Build Coastguard Worker    def test_trace_variable_instantiation(self):
1296*da0073e9SAndroid Build Coastguard Worker        def random_foo(x):
1297*da0073e9SAndroid Build Coastguard Worker            return Variable(Variable(x) + 1.0)
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(5, 6)
1302*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(random_foo(x), random_foo_traced(x))
1303*da0073e9SAndroid Build Coastguard Worker
1304*da0073e9SAndroid Build Coastguard Worker    def test_trace_slice_expr_complete_type(self):
1305*da0073e9SAndroid Build Coastguard Worker        def random_foo(x):
1306*da0073e9SAndroid Build Coastguard Worker            return x + 1.0
1307*da0073e9SAndroid Build Coastguard Worker
1308*da0073e9SAndroid Build Coastguard Worker        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1311*da0073e9SAndroid Build Coastguard Worker        def random_bar(x):
1312*da0073e9SAndroid Build Coastguard Worker            return random_foo_traced(x)[0:1]
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
1315*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(random_bar(x), (x + 1)[0:1])
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker    def test_trace_inline_shape(self):
1318*da0073e9SAndroid Build Coastguard Worker        # Test that the peephole optimization turns the size() call into a
1319*da0073e9SAndroid Build Coastguard Worker        # constant in a script fn
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1322*da0073e9SAndroid Build Coastguard Worker        def tensor_size(x: torch.Tensor) -> torch.Tensor:
1323*da0073e9SAndroid Build Coastguard Worker            return torch.tensor([x.size()[0]])
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1326*da0073e9SAndroid Build Coastguard Worker            tensor_size(
1327*da0073e9SAndroid Build Coastguard Worker                torch.rand(
1328*da0073e9SAndroid Build Coastguard Worker                    15,
1329*da0073e9SAndroid Build Coastguard Worker                )
1330*da0073e9SAndroid Build Coastguard Worker            ),
1331*da0073e9SAndroid Build Coastguard Worker            torch.tensor([15]),
1332*da0073e9SAndroid Build Coastguard Worker        )
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker        traced_tensor_size = torch.jit.trace(
1335*da0073e9SAndroid Build Coastguard Worker            tensor_size,
1336*da0073e9SAndroid Build Coastguard Worker            torch.rand(
1337*da0073e9SAndroid Build Coastguard Worker                7,
1338*da0073e9SAndroid Build Coastguard Worker            ),
1339*da0073e9SAndroid Build Coastguard Worker        )
1340*da0073e9SAndroid Build Coastguard Worker
1341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1342*da0073e9SAndroid Build Coastguard Worker            traced_tensor_size(
1343*da0073e9SAndroid Build Coastguard Worker                torch.rand(
1344*da0073e9SAndroid Build Coastguard Worker                    15,
1345*da0073e9SAndroid Build Coastguard Worker                )
1346*da0073e9SAndroid Build Coastguard Worker            ),
1347*da0073e9SAndroid Build Coastguard Worker            torch.tensor([15]),
1348*da0073e9SAndroid Build Coastguard Worker        )
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1351*da0073e9SAndroid Build Coastguard Worker        def use_device(x):
1352*da0073e9SAndroid Build Coastguard Worker            return torch.zeros_like(x, device=x.device)
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1355*da0073e9SAndroid Build Coastguard Worker            return use_device(x)
1356*da0073e9SAndroid Build Coastguard Worker
1357*da0073e9SAndroid Build Coastguard Worker        traced_tensor_size = torch.jit.trace(
1358*da0073e9SAndroid Build Coastguard Worker            foo,
1359*da0073e9SAndroid Build Coastguard Worker            torch.rand(
1360*da0073e9SAndroid Build Coastguard Worker                7,
1361*da0073e9SAndroid Build Coastguard Worker            ),
1362*da0073e9SAndroid Build Coastguard Worker        )
1363*da0073e9SAndroid Build Coastguard Worker        self.run_pass("inline", traced_tensor_size.graph)
1364*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::device").run(traced_tensor_size.graph)
1365*da0073e9SAndroid Build Coastguard Worker
1366*da0073e9SAndroid Build Coastguard Worker    def test_trace_save(self):
1367*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1368*da0073e9SAndroid Build Coastguard Worker            return x + 2
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker        def check(func):
1371*da0073e9SAndroid Build Coastguard Worker            with TemporaryFileName() as fname:
1372*da0073e9SAndroid Build Coastguard Worker                func.save(fname)
1373*da0073e9SAndroid Build Coastguard Worker                loaded = torch.jit.load(fname)
1374*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(2, 2)
1375*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(func(input), loaded(input))
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker        out = torch.jit.trace(fn, (torch.ones(2, 2),))
1378*da0073e9SAndroid Build Coastguard Worker        check(out)
1379*da0073e9SAndroid Build Coastguard Worker
1380*da0073e9SAndroid Build Coastguard Worker    def test_trace_optional_dtype(self):
1381*da0073e9SAndroid Build Coastguard Worker        class Test(torch.nn.Module):
1382*da0073e9SAndroid Build Coastguard Worker            def forward(self):
1383*da0073e9SAndroid Build Coastguard Worker                return torch.arange(5)
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(Test(), ())
1386*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(traced(), Test()()))
1387*da0073e9SAndroid Build Coastguard Worker
1388*da0073e9SAndroid Build Coastguard Worker    def test_trace_save_load_copy(self):
1389*da0073e9SAndroid Build Coastguard Worker        class Test(torch.nn.Module):
1390*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1391*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1392*da0073e9SAndroid Build Coastguard Worker                self.conv = torch.nn.Conv2d(3, 3, 3)
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1395*da0073e9SAndroid Build Coastguard Worker                return self.conv(x)
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224))
1398*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
1399*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(traced, buffer)
1400*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
1401*da0073e9SAndroid Build Coastguard Worker        loaded = torch.jit.load(buffer)
1402*da0073e9SAndroid Build Coastguard Worker        # should work
1403*da0073e9SAndroid Build Coastguard Worker        copy.copy(loaded)
1404*da0073e9SAndroid Build Coastguard Worker        copy.deepcopy(loaded)
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker    def test_trace_export_fns(self):
1407*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1408*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1409*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1410*da0073e9SAndroid Build Coastguard Worker                self.a = 3
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1413*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
1414*da0073e9SAndroid Build Coastguard Worker                return (3, self.training)
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1417*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
1418*da0073e9SAndroid Build Coastguard Worker                self.a = state[0]
1419*da0073e9SAndroid Build Coastguard Worker                self.training = state[1]
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1422*da0073e9SAndroid Build Coastguard Worker                return x + self.a
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker        f = Foo()
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(f, (torch.rand(3, 4),))
1427*da0073e9SAndroid Build Coastguard Worker        expected_names = ["__getstate__", "__setstate__"]
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker        def check(mod):
1430*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1431*da0073e9SAndroid Build Coastguard Worker                all(name in mod._c._method_names() for name in expected_names)
1432*da0073e9SAndroid Build Coastguard Worker            )
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker        check(traced)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(traced)
1437*da0073e9SAndroid Build Coastguard Worker        check(imported)
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Worker    def test_trace_export_fns_recursive(self):
1440*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
1441*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1442*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1443*da0073e9SAndroid Build Coastguard Worker                self.a = 3
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1446*da0073e9SAndroid Build Coastguard Worker            def __getstate__(self):
1447*da0073e9SAndroid Build Coastguard Worker                return (3, self.training)
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1450*da0073e9SAndroid Build Coastguard Worker            def __setstate__(self, state):
1451*da0073e9SAndroid Build Coastguard Worker                self.a = state[0]
1452*da0073e9SAndroid Build Coastguard Worker                self.training = state[1]
1453*da0073e9SAndroid Build Coastguard Worker
1454*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1455*da0073e9SAndroid Build Coastguard Worker                return x + self.a
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker        class Wrapper(torch.nn.Module):
1458*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1459*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1460*da0073e9SAndroid Build Coastguard Worker                self.foo = Foo()
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1463*da0073e9SAndroid Build Coastguard Worker                return self.foo(x)
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker        f = Wrapper()
1466*da0073e9SAndroid Build Coastguard Worker
1467*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(f, (torch.rand(3, 4),))
1468*da0073e9SAndroid Build Coastguard Worker        expected_names = ["__getstate__", "__setstate__"]
1469*da0073e9SAndroid Build Coastguard Worker
1470*da0073e9SAndroid Build Coastguard Worker        def check(mod):
1471*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1472*da0073e9SAndroid Build Coastguard Worker                all(name in mod._c._method_names() for name in expected_names)
1473*da0073e9SAndroid Build Coastguard Worker            )
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        check(traced.foo)
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(traced)
1478*da0073e9SAndroid Build Coastguard Worker        check(imported.foo)
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        # Note that Bar's forward can only be traced, but not scripted
1481*da0073e9SAndroid Build Coastguard Worker        class Bar(nn.Module):
1482*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1483*da0073e9SAndroid Build Coastguard Worker            def addTwo(self, x):
1484*da0073e9SAndroid Build Coastguard Worker                return x + 2
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
1487*da0073e9SAndroid Build Coastguard Worker                return (lambda a: a + 1)(input)  # noqa: PLC3002
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        # When tracing Bar as a submodule, only the exported methods should be
1490*da0073e9SAndroid Build Coastguard Worker        # scripted; the forward methods should remain traced.
1492*da0073e9SAndroid Build Coastguard Worker        class WrapperExports(torch.nn.Module):
1493*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1494*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1495*da0073e9SAndroid Build Coastguard Worker                self.bar = Bar()
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
1498*da0073e9SAndroid Build Coastguard Worker            def addOne(self, x):
1499*da0073e9SAndroid Build Coastguard Worker                return x + 1
1500*da0073e9SAndroid Build Coastguard Worker
1501*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1502*da0073e9SAndroid Build Coastguard Worker                return self.bar(x)
1503*da0073e9SAndroid Build Coastguard Worker
1504*da0073e9SAndroid Build Coastguard Worker        f = WrapperExports()
1505*da0073e9SAndroid Build Coastguard Worker
1506*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(f, (torch.rand(3, 4),))
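        # The wrapper's exported addOne should be scripted onto the traced
        # module, while Bar's forward stays traced.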
1507*da0073e9SAndroid Build Coastguard Worker        expected_names = ["addOne"]
1508*da0073e9SAndroid Build Coastguard Worker        check(traced)
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker    def test_trace_autograd_function(self):
1511*da0073e9SAndroid Build Coastguard Worker        class TestFunc(torch.autograd.Function):
1512*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1513*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
1514*da0073e9SAndroid Build Coastguard Worker                return torch.neg(input)
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1517*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
1518*da0073e9SAndroid Build Coastguard Worker                return torch.neg(grad_output)
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
1521*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1522*da0073e9SAndroid Build Coastguard Worker                return torch.relu(TestFunc.apply(x))
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        class Wrapper(torch.nn.Module):
1525*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1526*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1527*da0073e9SAndroid Build Coastguard Worker                self.tm = TracedModule()
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1530*da0073e9SAndroid Build Coastguard Worker                return self.tm(x)
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),))
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker    def test_trace_multi_output_function(self):
1535*da0073e9SAndroid Build Coastguard Worker        # An autograd.Function with two outputs. It swaps its inputs so we can
1536*da0073e9SAndroid Build Coastguard Worker        # check whether shape handling is correct in TorchScript.
1538*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.autograd.Function):
1539*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1540*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
1541*da0073e9SAndroid Build Coastguard Worker                return y, x
1542*da0073e9SAndroid Build Coastguard Worker
1543*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1544*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, du, dv):
1545*da0073e9SAndroid Build Coastguard Worker                return dv, du
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker        class Bar(torch.nn.Module):
1548*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
1549*da0073e9SAndroid Build Coastguard Worker                x = x.relu()
1550*da0073e9SAndroid Build Coastguard Worker                y = y.relu()
1551*da0073e9SAndroid Build Coastguard Worker                z = Foo.apply(x, y)
1552*da0073e9SAndroid Build Coastguard Worker                return z
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 2, dtype=torch.double)
1555*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(1, 2, dtype=torch.double)
1556*da0073e9SAndroid Build Coastguard Worker
1557*da0073e9SAndroid Build Coastguard Worker        # Generate JIT IR.
1558*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(Bar(), (x, y))
1559*da0073e9SAndroid Build Coastguard Worker        print(traced.graph)
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        # Expected output schema of the custom autograd.Function.
1562*da0073e9SAndroid Build Coastguard Worker        schema = (
1563*da0073e9SAndroid Build Coastguard Worker            "(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), "
1564*da0073e9SAndroid Build Coastguard Worker            "Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) "
1565*da0073e9SAndroid Build Coastguard Worker            "= ^Foo"
1566*da0073e9SAndroid Build Coastguard Worker        )
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker        # Check that the expected schema appears in the traced graph.
1569*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(schema).run(traced.graph)
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker        # Also check that the graph is runnable and produces the right result.
1573*da0073e9SAndroid Build Coastguard Worker        u, v = traced(x, y)
1574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(u, y)
1575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v, x)
1576*da0073e9SAndroid Build Coastguard Worker
1577*da0073e9SAndroid Build Coastguard Worker    def test_interpolate_trace(self):
1578*da0073e9SAndroid Build Coastguard Worker        class InterpolateModule(nn.Module):
1579*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1580*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1581*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
1582*da0073e9SAndroid Build Coastguard Worker
1583*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1584*da0073e9SAndroid Build Coastguard Worker                y = self.conv(x)
1585*da0073e9SAndroid Build Coastguard Worker                w = nn.functional.interpolate(
1586*da0073e9SAndroid Build Coastguard Worker                    y, mode="bilinear", align_corners=False, scale_factor=3
1587*da0073e9SAndroid Build Coastguard Worker                )
1588*da0073e9SAndroid Build Coastguard Worker                return w
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker        f = InterpolateModule()
1591*da0073e9SAndroid Build Coastguard Worker        # Tracing should not fail
1592*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
1593*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(1, 1, 14, 14)
1594*da0073e9SAndroid Build Coastguard Worker        # The output size must not be baked in as a constant, so a smaller input still matches eager
1595*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g(x), f(x))
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
1598*da0073e9SAndroid Build Coastguard Worker    def test_trace_optional(self):
1599*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1600*da0073e9SAndroid Build Coastguard Worker        def test(x: Optional[Tensor]):
1601*da0073e9SAndroid Build Coastguard Worker            if x is None:
1602*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(1)
1603*da0073e9SAndroid Build Coastguard Worker            else:
1604*da0073e9SAndroid Build Coastguard Worker                return x
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker        def test_none():
1607*da0073e9SAndroid Build Coastguard Worker            return test(None)
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker        def test_tensor():
1610*da0073e9SAndroid Build Coastguard Worker            return test(torch.zeros(2))
1611*da0073e9SAndroid Build Coastguard Worker
1612*da0073e9SAndroid Build Coastguard Worker        f_none = torch.jit.trace(test_none, ())
1613*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f_none(), torch.zeros(1))
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker        f_tensor = torch.jit.trace(test_tensor, ())
1616*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(f_tensor(), torch.zeros(2))
1617*da0073e9SAndroid Build Coastguard Worker
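        # With inlining disabled, the call to the scripted `test` function
        # should remain visible as a prim::CallFunction in the traced graph.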
1618*da0073e9SAndroid Build Coastguard Worker        graph = f_tensor.graph
1619*da0073e9SAndroid Build Coastguard Worker        FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph)
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker    def test_trace_nested_datatypes(self):
1622*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1623*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1624*da0073e9SAndroid Build Coastguard Worker            return [[x + 1, x - 1], [x + 2, x - 2]]
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker        def bar(x):
1627*da0073e9SAndroid Build Coastguard Worker            list_stuff = foo(x)
1628*da0073e9SAndroid Build Coastguard Worker            return list_stuff[0][0], list_stuff[1][1]
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(bar, torch.rand(3, 4))
1631*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(5, 6)
1632*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bar(x), traced(x))
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
1635*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_fn_from_traced_module(self):
1636*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
1637*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
1638*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
1641*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1642*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1643*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 5))
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1646*da0073e9SAndroid Build Coastguard Worker                return traced_fn(torch.mm(x, self.param))
1647*da0073e9SAndroid Build Coastguard Worker
1648*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        # Note: inlining is disabled here, so the call to traced_fn should remain a prim::CallFunction rather than being inlined
1651*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check('name="traced_fn"').check_next(
1652*da0073e9SAndroid Build Coastguard Worker            "prim::CallFunction"
1653*da0073e9SAndroid Build Coastguard Worker        ).run(str(tm.graph))
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
1656*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_module_from_traced_module(self):
1657*da0073e9SAndroid Build Coastguard Worker        class TracedModule1(torch.nn.Module):
1658*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1659*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1660*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(5, 7))
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1663*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
1666*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1667*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1668*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 5))
1669*da0073e9SAndroid Build Coastguard Worker                self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1672*da0073e9SAndroid Build Coastguard Worker                return self.mod(torch.mm(x, self.param)) + 1.0
1673*da0073e9SAndroid Build Coastguard Worker
1674*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
1677*da0073e9SAndroid Build Coastguard Worker            "forward"
1678*da0073e9SAndroid Build Coastguard Worker        ).check("aten::add").run(str(tm.graph))
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker    def test_index_put_trace_with_view(self):
1681*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
1682*da0073e9SAndroid Build Coastguard Worker        def test_index_put(target, indices, rhs):
1683*da0073e9SAndroid Build Coastguard Worker            target[indices] = rhs
1684*da0073e9SAndroid Build Coastguard Worker            return target
1685*da0073e9SAndroid Build Coastguard Worker
1686*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::view").check("index_put_").run(
1687*da0073e9SAndroid Build Coastguard Worker            str(test_index_put.graph)
1688*da0073e9SAndroid Build Coastguard Worker        )
1689*da0073e9SAndroid Build Coastguard Worker
1690*da0073e9SAndroid Build Coastguard Worker    def test_index_put_trace_without_view(self):
1691*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
1692*da0073e9SAndroid Build Coastguard Worker        def test_index_put(target, indices, rhs):
1693*da0073e9SAndroid Build Coastguard Worker            target[indices] = rhs
1694*da0073e9SAndroid Build Coastguard Worker            return target
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::view").check("index_put_").run(
1697*da0073e9SAndroid Build Coastguard Worker            str(test_index_put.graph)
1698*da0073e9SAndroid Build Coastguard Worker        )
1699*da0073e9SAndroid Build Coastguard Worker
1700*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1701*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_dot_data(self):
1702*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1703*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracingCheckError,
1704*da0073e9SAndroid Build Coastguard Worker            r"Tensor-valued Constant nodes differed in value across invocations",
1705*da0073e9SAndroid Build Coastguard Worker        ):
1706*da0073e9SAndroid Build Coastguard Worker
1707*da0073e9SAndroid Build Coastguard Worker            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
1708*da0073e9SAndroid Build Coastguard Worker            def foo(x):
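                # x.data is invisible to the tracer, so its value is baked into
                # the graph as a constant that differs across the check inputs.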
1709*da0073e9SAndroid Build Coastguard Worker                y = x.data
1710*da0073e9SAndroid Build Coastguard Worker                return x + y
1711*da0073e9SAndroid Build Coastguard Worker
1712*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1713*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_control_flow(self):
1714*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1715*da0073e9SAndroid Build Coastguard Worker            for _ in range(x.size(0)):
1716*da0073e9SAndroid Build Coastguard Worker                x = torch.neg(x)
1717*da0073e9SAndroid Build Coastguard Worker            return x
1718*da0073e9SAndroid Build Coastguard Worker
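        # The Python loop is unrolled for x.size(0) iterations, so a check
        # input with a different size produces a different graph.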
1719*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1720*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
1721*da0073e9SAndroid Build Coastguard Worker        ):
1722*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
1723*da0073e9SAndroid Build Coastguard Worker
1724*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1725*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_memoization(self):
1726*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1727*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
1728*da0073e9SAndroid Build Coastguard Worker        ):
1729*da0073e9SAndroid Build Coastguard Worker
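            # foo caches torch.neg(x) on the first call; reusing the cached
            # tensor makes the graph for the check input differ.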
1730*da0073e9SAndroid Build Coastguard Worker            def foo(x):
1731*da0073e9SAndroid Build Coastguard Worker                if not hasattr(foo, "cache"):
1732*da0073e9SAndroid Build Coastguard Worker                    foo.cache = torch.neg(x)
1733*da0073e9SAndroid Build Coastguard Worker                return x + foo.cache
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(
1736*da0073e9SAndroid Build Coastguard Worker                foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]
1737*da0073e9SAndroid Build Coastguard Worker            )
1738*da0073e9SAndroid Build Coastguard Worker
1739*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_slice_lhs(self):
1740*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1741*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
1742*da0073e9SAndroid Build Coastguard Worker                x[i, :] = torch.zeros(4)
1743*da0073e9SAndroid Build Coastguard Worker            return x
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_inplace_on_view(self):
1748*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1749*da0073e9SAndroid Build Coastguard Worker            x.view(-1).add_(-x.view(-1))
1750*da0073e9SAndroid Build Coastguard Worker            return x
1751*da0073e9SAndroid Build Coastguard Worker
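        # With _force_outplace=True the in-place add on the view no longer
        # mutates x, so the traced output diverges from the eager result.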
1752*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1753*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning,
1754*da0073e9SAndroid Build Coastguard Worker            "Output nr 1. of the traced function does not match the "
1755*da0073e9SAndroid Build Coastguard Worker            "corresponding output of the Python function",
1756*da0073e9SAndroid Build Coastguard Worker        ):
1757*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(
1758*da0073e9SAndroid Build Coastguard Worker                foo,
1759*da0073e9SAndroid Build Coastguard Worker                torch.rand(3, 4),
1760*da0073e9SAndroid Build Coastguard Worker                check_inputs=[torch.rand(5, 6)],
1761*da0073e9SAndroid Build Coastguard Worker                _force_outplace=True,
1762*da0073e9SAndroid Build Coastguard Worker            )
1763*da0073e9SAndroid Build Coastguard Worker
1764*da0073e9SAndroid Build Coastguard Worker    def test_lhs_index_fails(self):
1765*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1766*da0073e9SAndroid Build Coastguard Worker            x[0, 1] = 4
1767*da0073e9SAndroid Build Coastguard Worker            return x
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1770*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning, "cause the trace to be incorrect"
1771*da0073e9SAndroid Build Coastguard Worker        ):
1772*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
1773*da0073e9SAndroid Build Coastguard Worker
1774*da0073e9SAndroid Build Coastguard Worker    def test_lhs_index_trivial(self):
1775*da0073e9SAndroid Build Coastguard Worker        def foo(y, x):
1776*da0073e9SAndroid Build Coastguard Worker            y[...] = x
1777*da0073e9SAndroid Build Coastguard Worker            return y
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(
1780*da0073e9SAndroid Build Coastguard Worker            foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False
1781*da0073e9SAndroid Build Coastguard Worker        )
1782*da0073e9SAndroid Build Coastguard Worker
1783*da0073e9SAndroid Build Coastguard Worker    def test_inplace_warn(self):
1784*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1785*da0073e9SAndroid Build Coastguard Worker            x.view(-1).add_(-x.view(-1))
1786*da0073e9SAndroid Build Coastguard Worker            return x
1787*da0073e9SAndroid Build Coastguard Worker
1788*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1789*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning, "cause the trace to be incorrect"
1790*da0073e9SAndroid Build Coastguard Worker        ):
1791*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
1792*da0073e9SAndroid Build Coastguard Worker
1793*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
1794*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_dropout_train(self):
1795*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1796*da0073e9SAndroid Build Coastguard Worker            return torch.dropout(x, p=0.5, train=True)
1797*da0073e9SAndroid Build Coastguard Worker
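        # Dropout in training mode is nondeterministic, so the checker reports
        # both an output mismatch and nondeterministic nodes.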
1798*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1799*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning,
1800*da0073e9SAndroid Build Coastguard Worker            "Output nr 1. of the traced function does not match the "
1801*da0073e9SAndroid Build Coastguard Worker            "corresponding output of the Python function",
1802*da0073e9SAndroid Build Coastguard Worker        ):
1803*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
1804*da0073e9SAndroid Build Coastguard Worker
1805*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
1806*da0073e9SAndroid Build Coastguard Worker            torch.jit.TracerWarning, "Trace had nondeterministic nodes"
1807*da0073e9SAndroid Build Coastguard Worker        ):
1808*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Worker    def test_trace_checker_dropout_notrain(self):
1811*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(3, 4)
1812*da0073e9SAndroid Build Coastguard Worker
1813*da0073e9SAndroid Build Coastguard Worker        @_trace(input)
1814*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1815*da0073e9SAndroid Build Coastguard Worker            return torch.dropout(x, p=0.5, train=False)
1816*da0073e9SAndroid Build Coastguard Worker
1817*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(input), input)
1818*da0073e9SAndroid Build Coastguard Worker
1819*da0073e9SAndroid Build Coastguard Worker    def test_trace_contiguous(self):
1820*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1821*da0073e9SAndroid Build Coastguard Worker            return x[:, :, ::2].contiguous().view(12)
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3, 4)
1824*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (x,))
1825*da0073e9SAndroid Build Coastguard Worker        y = traced(x)
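        # contiguous() on the strided slice must copy, so the traced output
        # cannot share storage with the input.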
1826*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker    # This tests the logic in THPVariable_contiguous. There is short-circuiting
1829*da0073e9SAndroid Build Coastguard Worker    # code that prevents us from even getting to VariableType::contiguous, since
1830*da0073e9SAndroid Build Coastguard Worker    # it is an optimization that prevents us from acquiring the GIL for touching
1831*da0073e9SAndroid Build Coastguard Worker    # the device. We needed to add the tracing logic directly into the
1832*da0073e9SAndroid Build Coastguard Worker    # THPVariable_contiguous function only for the path where we are skipping
1833*da0073e9SAndroid Build Coastguard Worker    # dispatch into contiguous. We should see an aten::contiguous in this trace!
1834*da0073e9SAndroid Build Coastguard Worker    def test_trace_contiguous_short_circuit(self):
1835*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1836*da0073e9SAndroid Build Coastguard Worker            return x.contiguous()
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 3, 4)
1839*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (x,))
1840*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::contiguous").run(str(traced.graph))
1841*da0073e9SAndroid Build Coastguard Worker
1842*da0073e9SAndroid Build Coastguard Worker    def test_trace_inverse(self):
1843*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1844*da0073e9SAndroid Build Coastguard Worker            return ~x
1845*da0073e9SAndroid Build Coastguard Worker
1846*da0073e9SAndroid Build Coastguard Worker        foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8))
1847*da0073e9SAndroid Build Coastguard Worker        eg = torch.zeros(3, dtype=torch.uint8)
1848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_traced(eg), foo(eg))
1849*da0073e9SAndroid Build Coastguard Worker
1850*da0073e9SAndroid Build Coastguard Worker    def test_trace_modulelist(self):
1851*da0073e9SAndroid Build Coastguard Worker        class MySubmod(torch.nn.Module):
1852*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1853*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1854*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1857*da0073e9SAndroid Build Coastguard Worker                return self.relu(x)
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
1860*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1861*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1862*da0073e9SAndroid Build Coastguard Worker                self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()])
1863*da0073e9SAndroid Build Coastguard Worker
1864*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1865*da0073e9SAndroid Build Coastguard Worker                for mod in self.ml:
1866*da0073e9SAndroid Build Coastguard Worker                    x = mod(x)
1867*da0073e9SAndroid Build Coastguard Worker                return x
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
1870*da0073e9SAndroid Build Coastguard Worker
1871*da0073e9SAndroid Build Coastguard Worker    def test_trace_fork_join_and_module(self):
1872*da0073e9SAndroid Build Coastguard Worker        class MySubmod(torch.nn.Module):
1873*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1874*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1875*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1878*da0073e9SAndroid Build Coastguard Worker                return self.relu(x), torch.neg(x)
1879*da0073e9SAndroid Build Coastguard Worker
1880*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1881*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1882*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1883*da0073e9SAndroid Build Coastguard Worker                self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)])
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1886*da0073e9SAndroid Build Coastguard Worker                futs = []
1887*da0073e9SAndroid Build Coastguard Worker                for i in range(2):
1888*da0073e9SAndroid Build Coastguard Worker                    futs.append(torch.jit._fork(self.ml[i], x))
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker                results = []
1891*da0073e9SAndroid Build Coastguard Worker                for i in range(2):
1892*da0073e9SAndroid Build Coastguard Worker                    results.append(torch.jit._wait(futs[i])[0])
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker                return torch.stack(results)
1895*da0073e9SAndroid Build Coastguard Worker
1896*da0073e9SAndroid Build Coastguard Worker        m = Mod()
1897*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(m, torch.rand(3, 4))
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker    def test_trace_invert_module_hierarchy(self):
1900*da0073e9SAndroid Build Coastguard Worker        class MySubmod(torch.nn.Module):
1901*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1902*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1903*da0073e9SAndroid Build Coastguard Worker                self.relu = torch.nn.ReLU()
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1906*da0073e9SAndroid Build Coastguard Worker                return self.relu(x), torch.neg(x)
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker        class MyFunctionalMod(torch.nn.Module):
1909*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, submod):
1910*da0073e9SAndroid Build Coastguard Worker                return submod(x)
1911*da0073e9SAndroid Build Coastguard Worker
1912*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
1913*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1914*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1915*da0073e9SAndroid Build Coastguard Worker                self.sm = MySubmod()
1916*da0073e9SAndroid Build Coastguard Worker                self.fm = MyFunctionalMod()
1917*da0073e9SAndroid Build Coastguard Worker
1918*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1919*da0073e9SAndroid Build Coastguard Worker                return self.fm(x, self.sm)
1920*da0073e9SAndroid Build Coastguard Worker
1921*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(Mod(), (torch.rand(3, 4),))
1922*da0073e9SAndroid Build Coastguard Worker
1923*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1924*da0073e9SAndroid Build Coastguard Worker    def test_trace_records_names(self):
1925*da0073e9SAndroid Build Coastguard Worker        def foo(bar, baz):
1926*da0073e9SAndroid Build Coastguard Worker            baz = bar + 3
1927*da0073e9SAndroid Build Coastguard Worker            quick_brown_fox = torch.neg(baz)
1928*da0073e9SAndroid Build Coastguard Worker            for _ in range(20):
1929*da0073e9SAndroid Build Coastguard Worker                yeet = quick_brown_fox - 3.14
1930*da0073e9SAndroid Build Coastguard Worker            return yeet
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
1933*da0073e9SAndroid Build Coastguard Worker        graph_str = str(traced.graph)
1934*da0073e9SAndroid Build Coastguard Worker        assert "bar" in graph_str
1935*da0073e9SAndroid Build Coastguard Worker        assert "baz" in graph_str
1936*da0073e9SAndroid Build Coastguard Worker        assert "quick_brown_fox" in graph_str
1937*da0073e9SAndroid Build Coastguard Worker
1938*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1939*da0073e9SAndroid Build Coastguard Worker    def test_tracing_hooks(self):
1940*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
1941*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1942*da0073e9SAndroid Build Coastguard Worker                return x + x
1943*da0073e9SAndroid Build Coastguard Worker
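        # Helper: attach the hook to a fresh Net, trace it, check the traced
        # graph against fc, and verify the traced module matches eager mode.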
1944*da0073e9SAndroid Build Coastguard Worker        def test_hook(is_post_hook, hook, fc):
1945*da0073e9SAndroid Build Coastguard Worker            n = Net()
1946*da0073e9SAndroid Build Coastguard Worker            if is_post_hook:
1947*da0073e9SAndroid Build Coastguard Worker                n.register_forward_hook(hook)
1948*da0073e9SAndroid Build Coastguard Worker            else:
1949*da0073e9SAndroid Build Coastguard Worker                n.register_forward_pre_hook(hook)
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker            module = torch.jit.trace(n, (torch.tensor(1.0),))
1952*da0073e9SAndroid Build Coastguard Worker
1953*da0073e9SAndroid Build Coastguard Worker            eager_input = torch.tensor(1.0)
1954*da0073e9SAndroid Build Coastguard Worker            eager_out = n(eager_input)
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker            fc.run(module.forward.graph)
1957*da0073e9SAndroid Build Coastguard Worker            input = torch.tensor(1.0)
1958*da0073e9SAndroid Build Coastguard Worker            output = module(input)
1959*da0073e9SAndroid Build Coastguard Worker
1960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input, eager_input)
1961*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, eager_out)
1962*da0073e9SAndroid Build Coastguard Worker
1963*da0073e9SAndroid Build Coastguard Worker        def hook_no_return(mod, input, output):
1964*da0073e9SAndroid Build Coastguard Worker            input[0].add_(1)
1965*da0073e9SAndroid Build Coastguard Worker            output.sub_(1)
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check("add(").check("add_(").check("sub_(")
1968*da0073e9SAndroid Build Coastguard Worker        test_hook(True, hook_no_return, fc)
1969*da0073e9SAndroid Build Coastguard Worker
1970*da0073e9SAndroid Build Coastguard Worker        def hook_return(mod, input, output):
1971*da0073e9SAndroid Build Coastguard Worker            input[0].add_(1)
1972*da0073e9SAndroid Build Coastguard Worker            return output - 3
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check("add(").check("add_(").check("sub(")
1975*da0073e9SAndroid Build Coastguard Worker        test_hook(True, hook_return, fc)
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(3.0)
1978*da0073e9SAndroid Build Coastguard Worker
1979*da0073e9SAndroid Build Coastguard Worker        def captured_hook(mod, input, output):
1980*da0073e9SAndroid Build Coastguard Worker            return output - b
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check("add(").check("sub(")
1983*da0073e9SAndroid Build Coastguard Worker        test_hook(True, captured_hook, fc)
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker        def pre_hook_no_ret(mod, input):
1986*da0073e9SAndroid Build Coastguard Worker            input[0].add_(3)
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check("add_(").check("add(")
1989*da0073e9SAndroid Build Coastguard Worker        test_hook(False, pre_hook_no_ret, fc)
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker        def pre_hook_ret(mod, input):
1992*da0073e9SAndroid Build Coastguard Worker            return input[0] - 4
1993*da0073e9SAndroid Build Coastguard Worker
1994*da0073e9SAndroid Build Coastguard Worker        fc = FileCheck().check("sub(").check("add(")
1995*da0073e9SAndroid Build Coastguard Worker        test_hook(False, pre_hook_ret, fc)
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard Worker    def test_tracing_backward_hook_error(self):
1998*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
1999*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2000*da0073e9SAndroid Build Coastguard Worker                return x + x
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Worker        n = Net()
2003*da0073e9SAndroid Build Coastguard Worker
2004*da0073e9SAndroid Build Coastguard Worker        def backward_hook(module, grad_input, grad_output):
2005*da0073e9SAndroid Build Coastguard Worker            pass
2006*da0073e9SAndroid Build Coastguard Worker
2007*da0073e9SAndroid Build Coastguard Worker        n.register_backward_hook(backward_hook)
2008*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "backward hooks assigned"):
2009*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(n, (torch.tensor(1.0),))
2010*da0073e9SAndroid Build Coastguard Worker
2011*da0073e9SAndroid Build Coastguard Worker    def test_tracing_multiple_methods(self):
2012*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
2013*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2014*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2015*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(1, 1, 3)
2016*da0073e9SAndroid Build Coastguard Worker
2017*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2018*da0073e9SAndroid Build Coastguard Worker                return self.conv(x)
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker            def weighted_kernel_sum(self, weight):
2021*da0073e9SAndroid Build Coastguard Worker                return weight * self.conv.weight
2022*da0073e9SAndroid Build Coastguard Worker
2023*da0073e9SAndroid Build Coastguard Worker        example_weight = torch.rand(1, 1, 3, 3)
2024*da0073e9SAndroid Build Coastguard Worker        example_forward_input = torch.rand(1, 1, 3, 3)
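        # torch.jit.trace_module takes a dict mapping method names to example
        # inputs, tracing each listed method.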
2025*da0073e9SAndroid Build Coastguard Worker        inputs = {
2026*da0073e9SAndroid Build Coastguard Worker            "forward": example_forward_input,
2027*da0073e9SAndroid Build Coastguard Worker            "weighted_kernel_sum": example_weight,
2028*da0073e9SAndroid Build Coastguard Worker        }
2029*da0073e9SAndroid Build Coastguard Worker        n = Net()
2030*da0073e9SAndroid Build Coastguard Worker        module = torch.jit.trace_module(n, inputs)
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker        check_inputs = []
2033*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
2034*da0073e9SAndroid Build Coastguard Worker            check_weight = torch.rand(1, 1, 3, 3)
2035*da0073e9SAndroid Build Coastguard Worker            check_forward_input = torch.rand(1, 1, 3, 3)
2036*da0073e9SAndroid Build Coastguard Worker            check_inputs.append(
2037*da0073e9SAndroid Build Coastguard Worker                {"forward": check_forward_input, "weighted_kernel_sum": check_weight}
2038*da0073e9SAndroid Build Coastguard Worker            )
2039*da0073e9SAndroid Build Coastguard Worker        module = torch.jit.trace_module(
2040*da0073e9SAndroid Build Coastguard Worker            n, inputs, check_trace=True, check_inputs=check_inputs
2041*da0073e9SAndroid Build Coastguard Worker        )
2042*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(module._c._has_method("forward"))
2043*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(module._c._has_method("weighted_kernel_sum"))
2044*da0073e9SAndroid Build Coastguard Worker
2045*da0073e9SAndroid Build Coastguard Worker        module = torch.jit.trace(n.forward, example_forward_input)
2046*da0073e9SAndroid Build Coastguard Worker        module = torch.jit.trace(
2047*da0073e9SAndroid Build Coastguard Worker            n.forward,
2048*da0073e9SAndroid Build Coastguard Worker            example_forward_input,
2049*da0073e9SAndroid Build Coastguard Worker            check_trace=True,
2050*da0073e9SAndroid Build Coastguard Worker            check_inputs=[example_forward_input],
2051*da0073e9SAndroid Build Coastguard Worker        )
2052*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2053*da0073e9SAndroid Build Coastguard Worker            AttributeError,
2054*da0073e9SAndroid Build Coastguard Worker            "trace doesn't support compiling individual module's functions",
2055*da0073e9SAndroid Build Coastguard Worker        ):
2056*da0073e9SAndroid Build Coastguard Worker            module = torch.jit.trace(n.weighted_kernel_sum, inputs)
2057*da0073e9SAndroid Build Coastguard Worker
2058*da0073e9SAndroid Build Coastguard Worker    def test_tensor_with_grad_as_constant(self):
2059*da0073e9SAndroid Build Coastguard Worker        param = torch.randn(3).requires_grad_()
2060*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3)
2061*da0073e9SAndroid Build Coastguard Worker
2062*da0073e9SAndroid Build Coastguard Worker        def f(x):
2063*da0073e9SAndroid Build Coastguard Worker            return x + param
2064*da0073e9SAndroid Build Coastguard Worker
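        # param is closed over rather than passed in, so the tracer would have
        # to embed it as a constant, which is disallowed for tensors that
        # require grad.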
2065*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2066*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Cannot insert a Tensor that requires grad as a constant"
2067*da0073e9SAndroid Build Coastguard Worker        ):
2068*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(f, x)
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker    def test_non_tensor_tracing(self):
2071*da0073e9SAndroid Build Coastguard Worker        def f(x):
2072*da0073e9SAndroid Build Coastguard Worker            return x + param  # noqa: F821
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2075*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"
2076*da0073e9SAndroid Build Coastguard Worker        ):
2077*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(f, (1,))
2078*da0073e9SAndroid Build Coastguard Worker
2079*da0073e9SAndroid Build Coastguard Worker    def test_trace_skip_none_submodule(self):
2080*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
2081*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2082*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2083*da0073e9SAndroid Build Coastguard Worker                self.submod = torch.nn.Linear(3, 4)
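                # Overwrite the submodule with None; tracing should skip it
                # entirely.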
2084*da0073e9SAndroid Build Coastguard Worker                self.submod = None
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker            def forward(self, inputs):
2087*da0073e9SAndroid Build Coastguard Worker                return inputs
2088*da0073e9SAndroid Build Coastguard Worker
2089*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
2090*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(m, torch.tensor(1.0))
2091*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hasattr(tm, "submod"))
2092*da0073e9SAndroid Build Coastguard Worker
2093*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_conditional_property(self):
2094*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
2095*da0073e9SAndroid Build Coastguard Worker            def __init__(self, attr=None):
2096*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2097*da0073e9SAndroid Build Coastguard Worker                if attr is not None:
2098*da0073e9SAndroid Build Coastguard Worker                    self._attr = attr
2099*da0073e9SAndroid Build Coastguard Worker                self.attr_name = "_attr"
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker            @property
2102*da0073e9SAndroid Build Coastguard Worker            def attr(self):
2103*da0073e9SAndroid Build Coastguard Worker                return getattr(self, self.attr_name)
2104*da0073e9SAndroid Build Coastguard Worker
2105*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2106*da0073e9SAndroid Build Coastguard Worker                return x
2107*da0073e9SAndroid Build Coastguard Worker
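        # forward never touches .attr, so tracing should succeed even though
        # _attr was never set on this instance.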
2108*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
2109*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(Net(), x)
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker    def test_trace_func_argument_names_captured(self):
2112*da0073e9SAndroid Build Coastguard Worker        def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
2113*da0073e9SAndroid Build Coastguard Worker            return first_arg + second_arg
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Worker        traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
2116*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("first_arg").check_next("second_arg").run(
2117*da0073e9SAndroid Build Coastguard Worker            str(traced_fn.graph)
2118*da0073e9SAndroid Build Coastguard Worker        )
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    def test_trace_partial_func_argument_names_captured(self):
2121*da0073e9SAndroid Build Coastguard Worker        def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
2122*da0073e9SAndroid Build Coastguard Worker            return first_arg + second_arg
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker        traced_fn = torch.jit.trace(fn, (torch.ones(1),))
2125*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph))
2126*da0073e9SAndroid Build Coastguard Worker
2127*da0073e9SAndroid Build Coastguard Worker    def test_trace_module_argument_names_captured(self):
2128*da0073e9SAndroid Build Coastguard Worker        class TestModule(nn.Module):
2129*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2130*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2131*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(1, 1, 3)
2132*da0073e9SAndroid Build Coastguard Worker
2133*da0073e9SAndroid Build Coastguard Worker            def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
2134*da0073e9SAndroid Build Coastguard Worker                return self.conv(first_arg) + second_arg
2135*da0073e9SAndroid Build Coastguard Worker
2136*da0073e9SAndroid Build Coastguard Worker        m = TestModule()
2137*da0073e9SAndroid Build Coastguard Worker        example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker        # Explicitly tracing the module's forward method
2140*da0073e9SAndroid Build Coastguard Worker        traced_module_forward = torch.jit.trace(m.forward, example_input)
2141*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("first_arg").check_next("second_arg").run(
2142*da0073e9SAndroid Build Coastguard Worker            str(traced_module_forward.graph)
2143*da0073e9SAndroid Build Coastguard Worker        )
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker        # Tracing the module directly
2146*da0073e9SAndroid Build Coastguard Worker        traced_module = torch.jit.trace(m, example_input)
2147*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("first_arg").check_next("second_arg").run(
2148*da0073e9SAndroid Build Coastguard Worker            str(traced_module.graph)
2149*da0073e9SAndroid Build Coastguard Worker        )
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker    def test_trace_checking_with_deprecated_name(self):
2152*da0073e9SAndroid Build Coastguard Worker        class MyClass(torch.nn.Module):
2153*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2154*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2155*da0073e9SAndroid Build Coastguard Worker
2156*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y, **deprecated_arguments):
2157*da0073e9SAndroid Build Coastguard Worker                if len(deprecated_arguments) > 0:
2158*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(
2159*da0073e9SAndroid Build Coastguard Worker                        f"Got unexpected arguments: {deprecated_arguments}"
2160*da0073e9SAndroid Build Coastguard Worker                    )
2161*da0073e9SAndroid Build Coastguard Worker                return x + y
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker        model = MyClass()
2164*da0073e9SAndroid Build Coastguard Worker        m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
2165*da0073e9SAndroid Build Coastguard Worker        m3 = torch.jit.trace(
2166*da0073e9SAndroid Build Coastguard Worker            model,
2167*da0073e9SAndroid Build Coastguard Worker            example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)},
2168*da0073e9SAndroid Build Coastguard Worker            strict=False,
2169*da0073e9SAndroid Build Coastguard Worker        )
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_tuple_tensor(self):
2172*da0073e9SAndroid Build Coastguard Worker        class MyClass(torch.nn.Module):
2173*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2174*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
2177*da0073e9SAndroid Build Coastguard Worker                return x + y[0] + y[1]
2178*da0073e9SAndroid Build Coastguard Worker
2179*da0073e9SAndroid Build Coastguard Worker        model = MyClass()
2180*da0073e9SAndroid Build Coastguard Worker        traced_model = torch.jit.trace(
2181*da0073e9SAndroid Build Coastguard Worker            model, (torch.ones(1), (torch.ones(1), torch.ones(1)))
2182*da0073e9SAndroid Build Coastguard Worker        )
2183*da0073e9SAndroid Build Coastguard Worker        input_dict = {
2184*da0073e9SAndroid Build Coastguard Worker            "x": torch.tensor([2, 3]),
2185*da0073e9SAndroid Build Coastguard Worker            "y": (torch.tensor([5, 6]), torch.tensor([7, 8])),
2186*da0073e9SAndroid Build Coastguard Worker        }
2187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model(**input_dict), traced_model(**input_dict))
2188*da0073e9SAndroid Build Coastguard Worker        traced_model = torch.jit.trace(
2189*da0073e9SAndroid Build Coastguard Worker            model,
2190*da0073e9SAndroid Build Coastguard Worker            example_kwarg_inputs={
2191*da0073e9SAndroid Build Coastguard Worker                "x": torch.ones(1),
2192*da0073e9SAndroid Build Coastguard Worker                "y": (torch.ones(1), torch.ones(1)),
2193*da0073e9SAndroid Build Coastguard Worker            },
2194*da0073e9SAndroid Build Coastguard Worker        )
2195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model(**input_dict), traced_model(**input_dict))
2196*da0073e9SAndroid Build Coastguard Worker
2197*da0073e9SAndroid Build Coastguard Worker    def test_trace_no_duplicated_lifted_input_output(self):
2198*da0073e9SAndroid Build Coastguard Worker        class Normalize(nn.Module):
2199*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2200*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2201*da0073e9SAndroid Build Coastguard Worker                self.norm = nn.GroupNorm(num_groups=32, num_channels=32)
2202*da0073e9SAndroid Build Coastguard Worker
2203*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
2204*da0073e9SAndroid Build Coastguard Worker                if y is None:
2205*da0073e9SAndroid Build Coastguard Worker                    y = x
2206*da0073e9SAndroid Build Coastguard Worker                else:
2207*da0073e9SAndroid Build Coastguard Worker                    y = self.norm(y)
2208*da0073e9SAndroid Build Coastguard Worker                y = y * 2
2209*da0073e9SAndroid Build Coastguard Worker                return y
2210*da0073e9SAndroid Build Coastguard Worker
2211*da0073e9SAndroid Build Coastguard Worker        class G(nn.Module):
2212*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2213*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2214*da0073e9SAndroid Build Coastguard Worker                self.norm = Normalize()
2215*da0073e9SAndroid Build Coastguard Worker
2216*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2217*da0073e9SAndroid Build Coastguard Worker                A = self.norm(x, None)
2218*da0073e9SAndroid Build Coastguard Worker                B = F.relu(A)
2219*da0073e9SAndroid Build Coastguard Worker                return A, B
2220*da0073e9SAndroid Build Coastguard Worker
2221*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
2222*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2223*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2224*da0073e9SAndroid Build Coastguard Worker                self.g = G()
2225*da0073e9SAndroid Build Coastguard Worker                self.norm_1 = Normalize()
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2228*da0073e9SAndroid Build Coastguard Worker                hs = self.g(x)
2229*da0073e9SAndroid Build Coastguard Worker                A, B = hs
2230*da0073e9SAndroid Build Coastguard Worker                h = self.norm_1(B, A)
2231*da0073e9SAndroid Build Coastguard Worker                return h
2232*da0073e9SAndroid Build Coastguard Worker
2233*da0073e9SAndroid Build Coastguard Worker        net = Net()
2234*da0073e9SAndroid Build Coastguard Worker        net = net.eval()
2235*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 32, 16, 16)
2236*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(net, x)
2237*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph))
2238*da0073e9SAndroid Build Coastguard Worker
2239*da0073e9SAndroid Build Coastguard Worker
2240*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
2241*da0073e9SAndroid Build Coastguard Workerclass TestMixTracingScripting(JitTestCase):
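    # These tests exercise interleaving torch.jit.trace and torch.jit.script:
    # traced code calling scripted functions/modules and vice versa, plus the
    # container types (tuples, lists, dicts) that cross those boundaries.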
2242*da0073e9SAndroid Build Coastguard Worker    def test_trace_script(self):
2243*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2244*da0073e9SAndroid Build Coastguard Worker        def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
2245*da0073e9SAndroid Build Coastguard Worker            return x[0] + x[1]
2246*da0073e9SAndroid Build Coastguard Worker
2247*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2248*da0073e9SAndroid Build Coastguard Worker        def func2(x: List[Tensor]) -> Tensor:
2249*da0073e9SAndroid Build Coastguard Worker            return x[0] + x[1]
2250*da0073e9SAndroid Build Coastguard Worker
2251*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5)
2252*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5)
2253*da0073e9SAndroid Build Coastguard Worker
2254*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(func1, ((a, b),))
2255*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(func2, ((a, b),))
2256*da0073e9SAndroid Build Coastguard Worker
2257*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2258*da0073e9SAndroid Build Coastguard Worker        def func3(
2259*da0073e9SAndroid Build Coastguard Worker            x: Tensor, method: str = "bilinear", align_corners: bool = True
2260*da0073e9SAndroid Build Coastguard Worker        ) -> Tensor:
2261*da0073e9SAndroid Build Coastguard Worker            hw = x.shape[2:4]
2262*da0073e9SAndroid Build Coastguard Worker            return F.interpolate(x, hw, mode=method, align_corners=align_corners)
2263*da0073e9SAndroid Build Coastguard Worker
2264*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 3, 6, 6)
2265*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(func3, (inp,))
2266*da0073e9SAndroid Build Coastguard Worker
2267*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2268*da0073e9SAndroid Build Coastguard Worker        def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
2269*da0073e9SAndroid Build Coastguard Worker            if len(a) == 2:
2270*da0073e9SAndroid Build Coastguard Worker                return x + 2
2271*da0073e9SAndroid Build Coastguard Worker            else:
2272*da0073e9SAndroid Build Coastguard Worker                return x
2273*da0073e9SAndroid Build Coastguard Worker
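
    # A scripted function may return a Dict even though strict tracing cannot
    # represent dict outputs at the top level; here the dict is consumed
    # inside the traced forward, so only tensors cross the trace boundary.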
2274*da0073e9SAndroid Build Coastguard Worker    def test_trace_mixed_by_script_with_dict_output(self):
2275*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2276*da0073e9SAndroid Build Coastguard Worker        def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]:
2277*da0073e9SAndroid Build Coastguard Worker            return {"foo": input + 1}
2278*da0073e9SAndroid Build Coastguard Worker
2279*da0073e9SAndroid Build Coastguard Worker        class TraceModule(torch.nn.Module):
2280*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2281*da0073e9SAndroid Build Coastguard Worker                d = return_dict(input)
2282*da0073e9SAndroid Build Coastguard Worker                return d["foo"] + d["foo"]
2283*da0073e9SAndroid Build Coastguard Worker
2284*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
2285*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TraceModule(), x)
2286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tm(x), x + 1 + x + 1)
2287*da0073e9SAndroid Build Coastguard Worker
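    # Calling a @torch.jit.script function from traced code preserves the
    # scripted function's data-dependent branch: the traced `use` returns
    # different values for zero and non-zero inputs below instead of baking
    # in the branch taken during tracing.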
2288*da0073e9SAndroid Build Coastguard Worker    def test_trace_of_script(self):
2289*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2290*da0073e9SAndroid Build Coastguard Worker        def foo(a, c):
2291*da0073e9SAndroid Build Coastguard Worker            b = 0.0
2292*da0073e9SAndroid Build Coastguard Worker            if bool(a == 0.0):
2293*da0073e9SAndroid Build Coastguard Worker                b = 1.0
2294*da0073e9SAndroid Build Coastguard Worker            return b + c
2295*da0073e9SAndroid Build Coastguard Worker
2296*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, dtype=torch.float)
2297*da0073e9SAndroid Build Coastguard Worker
2298*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(1, dtype=torch.float))
2299*da0073e9SAndroid Build Coastguard Worker        def use(b):
2300*da0073e9SAndroid Build Coastguard Worker            return foo(b - 1.0, a) + 1.0
2301*da0073e9SAndroid Build Coastguard Worker
2302*da0073e9SAndroid Build Coastguard Worker        # test that shapes were propagated through the function
2303*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("Dynamic" not in str(use.graph))
2304*da0073e9SAndroid Build Coastguard Worker
2305*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
2306*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
2307*da0073e9SAndroid Build Coastguard Worker
2308*da0073e9SAndroid Build Coastguard Worker    def test_trace_with_size(self):
2309*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(1, 1))
2310*da0073e9SAndroid Build Coastguard Worker        def foo(x):
2311*da0073e9SAndroid Build Coastguard Worker            return x + 1
2312*da0073e9SAndroid Build Coastguard Worker
2313*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2314*da0073e9SAndroid Build Coastguard Worker        def bar(x):
2315*da0073e9SAndroid Build Coastguard Worker            y = int(foo(x))
2316*da0073e9SAndroid Build Coastguard Worker            if 1 == 1:
2317*da0073e9SAndroid Build Coastguard Worker                y = 7
2318*da0073e9SAndroid Build Coastguard Worker            return y + 1
2319*da0073e9SAndroid Build Coastguard Worker
2320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(8, bar(torch.ones(1, 1)))
2321*da0073e9SAndroid Build Coastguard Worker
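    # Tracing a negative slice should record the slice with its negative
    # bounds, matching eager and scripted results even for an input whose
    # length differs from the traced example; the assertNotEqual guards
    # against the output being a constant captured from the example input.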
2322*da0073e9SAndroid Build Coastguard Worker    def test_tracing_slicing(self):
2323*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(10))
2324*da0073e9SAndroid Build Coastguard Worker        def foo_trace(x):
2325*da0073e9SAndroid Build Coastguard Worker            return x[-5:-3]
2326*da0073e9SAndroid Build Coastguard Worker
2327*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2328*da0073e9SAndroid Build Coastguard Worker        def foo_script(x):
2329*da0073e9SAndroid Build Coastguard Worker            return x[-5:-3]
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker        def foo(x):
2332*da0073e9SAndroid Build Coastguard Worker            return x[-5:-3]
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(0, 8)
2335*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(0, 20)
2336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_trace(a), foo_script(a))
2337*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_trace(a), foo(a))
2338*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(foo_trace(a), foo_trace(b))
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker    def test_tracing_indexing(self):
2341*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.zeros(10))
2342*da0073e9SAndroid Build Coastguard Worker        def foo_trace(x):
2343*da0073e9SAndroid Build Coastguard Worker            return x[-2]
2344*da0073e9SAndroid Build Coastguard Worker
2345*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2346*da0073e9SAndroid Build Coastguard Worker        def foo_script(x):
2347*da0073e9SAndroid Build Coastguard Worker            return x[-2]
2348*da0073e9SAndroid Build Coastguard Worker
2349*da0073e9SAndroid Build Coastguard Worker        def foo(x):
2350*da0073e9SAndroid Build Coastguard Worker            return x[-2]
2351*da0073e9SAndroid Build Coastguard Worker
2352*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(0, 8)
2353*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(0, 20)
2354*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_script(a), foo_trace(a))
2355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_trace(a), foo(a))
2356*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(foo_trace(a), foo_trace(b))
2357*da0073e9SAndroid Build Coastguard Worker
2358*da0073e9SAndroid Build Coastguard Worker    def test_trace_hierarchy(self):
2359*da0073e9SAndroid Build Coastguard Worker        # Test that we preserve the module hierarchy for a ScriptModule
2360*da0073e9SAndroid Build Coastguard Worker        # submodule during tracing
2361*da0073e9SAndroid Build Coastguard Worker
2362*da0073e9SAndroid Build Coastguard Worker        class AnotherScriptMod(torch.jit.ScriptModule):
2363*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2364*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2365*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
2366*da0073e9SAndroid Build Coastguard Worker
2367*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2368*da0073e9SAndroid Build Coastguard Worker            def bar(self):
2369*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(4, 5)
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker        class SomeScriptMod(torch.jit.ScriptModule):
2372*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2373*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2374*da0073e9SAndroid Build Coastguard Worker                self.asm = AnotherScriptMod()
2375*da0073e9SAndroid Build Coastguard Worker
2376*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2377*da0073e9SAndroid Build Coastguard Worker            def foo(self):
2378*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(3, 4)
2379*da0073e9SAndroid Build Coastguard Worker
2380*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2381*da0073e9SAndroid Build Coastguard Worker            def bar(self):
2382*da0073e9SAndroid Build Coastguard Worker                return torch.zeros(4, 3)
2383*da0073e9SAndroid Build Coastguard Worker
2384*da0073e9SAndroid Build Coastguard Worker        class TraceMe(torch.nn.Module):
2385*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2386*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2387*da0073e9SAndroid Build Coastguard Worker                self.ssm = SomeScriptMod()
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2390*da0073e9SAndroid Build Coastguard Worker                return self.ssm.bar() + x
2391*da0073e9SAndroid Build Coastguard Worker
2392*da0073e9SAndroid Build Coastguard Worker        orig = TraceMe()
2393*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(orig, (torch.rand(4, 3),))
2394*da0073e9SAndroid Build Coastguard Worker        # For each of these checks, verify that *both* the underlying
2395*da0073e9SAndroid Build Coastguard Worker        # _C.ScriptModule object and the Python object that wraps it
2396*da0073e9SAndroid Build Coastguard Worker        # have the expected method/param.
2397*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(traced.ssm._c._has_method("foo"))
2398*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(traced.ssm, "foo"))
2399*da0073e9SAndroid Build Coastguard Worker
2400*da0073e9SAndroid Build Coastguard Worker        imported = self.getExportImportCopy(traced)
2401*da0073e9SAndroid Build Coastguard Worker
2402*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(imported.ssm._c._has_method("foo"))
2403*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(imported.ssm, "foo"))
2404*da0073e9SAndroid Build Coastguard Worker
2405*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(imported.ssm.asm._c._has_method("bar"))
2406*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(imported.ssm.asm, "bar"))
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(imported.ssm.asm, "param"))
2409*da0073e9SAndroid Build Coastguard Worker
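    # Tracing and scripting can nest arbitrarily: ScriptModule M1 traces an
    # nn.Module M2 that wraps ScriptModule M3, which itself traces a module
    # with a registered parameter. The result should still serialize via
    # torch.jit.save.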
2410*da0073e9SAndroid Build Coastguard Worker    def test_trace_parameter(self):
2411*da0073e9SAndroid Build Coastguard Worker        class Param(nn.Module):
2412*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2413*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2414*da0073e9SAndroid Build Coastguard Worker                self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))
2415*da0073e9SAndroid Build Coastguard Worker
2416*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2417*da0073e9SAndroid Build Coastguard Worker                return x
2418*da0073e9SAndroid Build Coastguard Worker
2419*da0073e9SAndroid Build Coastguard Worker        class M3(torch.jit.ScriptModule):
2420*da0073e9SAndroid Build Coastguard Worker            def __init__(self, model):
2421*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2422*da0073e9SAndroid Build Coastguard Worker                self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2425*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2426*da0073e9SAndroid Build Coastguard Worker                return self.traced(x)
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker        class M2(nn.Module):
2429*da0073e9SAndroid Build Coastguard Worker            def __init__(self, model):
2430*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2431*da0073e9SAndroid Build Coastguard Worker                self.module = M3(model)
2432*da0073e9SAndroid Build Coastguard Worker
2433*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2434*da0073e9SAndroid Build Coastguard Worker                return self.module(x)
2435*da0073e9SAndroid Build Coastguard Worker
2436*da0073e9SAndroid Build Coastguard Worker        class M1(torch.jit.ScriptModule):
2437*da0073e9SAndroid Build Coastguard Worker            def __init__(self, model):
2438*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2439*da0073e9SAndroid Build Coastguard Worker                self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
2440*da0073e9SAndroid Build Coastguard Worker
2441*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2442*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2443*da0073e9SAndroid Build Coastguard Worker                return self.traced(x)
2444*da0073e9SAndroid Build Coastguard Worker
2445*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
2446*da0073e9SAndroid Build Coastguard Worker            module = M1(Param())
2447*da0073e9SAndroid Build Coastguard Worker            f = io.BytesIO()
2448*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(module, f)
2449*da0073e9SAndroid Build Coastguard Worker
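    # With inlining temporarily disabled (the decorator below), calls from
    # traced code into scripted functions/modules should appear as explicit
    # prim::CallFunction / prim::CallMethod nodes, which is what the
    # FileCheck patterns in the next few tests assert.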
2450*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
2451*da0073e9SAndroid Build Coastguard Worker    def test_call_script_fn_from_traced_module(self):
2452*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2453*da0073e9SAndroid Build Coastguard Worker        def scripted_fn(x):
2454*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
2455*da0073e9SAndroid Build Coastguard Worker
2456*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
2457*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2458*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2459*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 5))
2460*da0073e9SAndroid Build Coastguard Worker
2461*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2462*da0073e9SAndroid Build Coastguard Worker                return scripted_fn(torch.mm(x, self.param))
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2465*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check('name="scripted_fn"').check(
2466*da0073e9SAndroid Build Coastguard Worker            "prim::CallFunction"
2467*da0073e9SAndroid Build Coastguard Worker        ).run(str(tm.graph))
2468*da0073e9SAndroid Build Coastguard Worker
2469*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
2470*da0073e9SAndroid Build Coastguard Worker    def test_call_script_module_from_traced_module(self):
2471*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
2472*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2473*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2474*da0073e9SAndroid Build Coastguard Worker                self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
2475*da0073e9SAndroid Build Coastguard Worker
2476*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2477*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2478*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param_foo)
2479*da0073e9SAndroid Build Coastguard Worker
2480*da0073e9SAndroid Build Coastguard Worker        class TracedModule(torch.nn.Module):
2481*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2482*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2483*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 5))
2484*da0073e9SAndroid Build Coastguard Worker                self.mod = ScriptMod()
2485*da0073e9SAndroid Build Coastguard Worker
2486*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2487*da0073e9SAndroid Build Coastguard Worker                return self.mod(torch.mm(x, self.param)) + 1.0
2488*da0073e9SAndroid Build Coastguard Worker
2489*da0073e9SAndroid Build Coastguard Worker        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2490*da0073e9SAndroid Build Coastguard Worker
2491*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
2492*da0073e9SAndroid Build Coastguard Worker            "forward"
2493*da0073e9SAndroid Build Coastguard Worker        ).check("aten::add").run(str(tm.graph))
2494*da0073e9SAndroid Build Coastguard Worker
2495*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
2496*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_fn_from_script_fn(self):
2497*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 4))
2498*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
2499*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
2500*da0073e9SAndroid Build Coastguard Worker
2501*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2502*da0073e9SAndroid Build Coastguard Worker        def script_fn(x):
2503*da0073e9SAndroid Build Coastguard Worker            return traced_fn(x) + 1
2504*da0073e9SAndroid Build Coastguard Worker
2505*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::CallFunction").check("aten::add").run(
2506*da0073e9SAndroid Build Coastguard Worker            str(script_fn.graph)
2507*da0073e9SAndroid Build Coastguard Worker        )
2508*da0073e9SAndroid Build Coastguard Worker
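    # A traced module is a ScriptModule; calling it from a free scripted
    # function (rather than as a registered submodule) is expected to raise
    # the "not a submodule of the caller" error checked below.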
2509*da0073e9SAndroid Build Coastguard Worker    def test_call_traced_mod_from_script_fn(self):
2510*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2511*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2512*da0073e9SAndroid Build Coastguard Worker            "Cannot call a ScriptModule that is not a submodule of the caller",
2513*da0073e9SAndroid Build Coastguard Worker        ):
2514*da0073e9SAndroid Build Coastguard Worker
2515*da0073e9SAndroid Build Coastguard Worker            class TracedModule(torch.nn.Module):
2516*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
2517*da0073e9SAndroid Build Coastguard Worker                    return torch.mm(x, torch.zeros(4, 3))
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Worker            tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
2520*da0073e9SAndroid Build Coastguard Worker
2521*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
2522*da0073e9SAndroid Build Coastguard Worker            def script_fn(x):
2523*da0073e9SAndroid Build Coastguard Worker                return tm(x) + 1
2524*da0073e9SAndroid Build Coastguard Worker
2525*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
2526*da0073e9SAndroid Build Coastguard Worker    def test_call_tracing_fn_from_script_module(self):
2527*da0073e9SAndroid Build Coastguard Worker        @_trace(torch.rand(3, 3))
2528*da0073e9SAndroid Build Coastguard Worker        def traced_fn(x):
2529*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
2532*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2533*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2534*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
2535*da0073e9SAndroid Build Coastguard Worker
2536*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2537*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2538*da0073e9SAndroid Build Coastguard Worker                return traced_fn(torch.mm(x, self.param))
2539*da0073e9SAndroid Build Coastguard Worker
2540*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
2541*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("prim::CallFunction").run(
2542*da0073e9SAndroid Build Coastguard Worker            str(sm.forward.graph)
2543*da0073e9SAndroid Build Coastguard Worker        )
2544*da0073e9SAndroid Build Coastguard Worker
2545*da0073e9SAndroid Build Coastguard Worker    @_tmp_donotuse_dont_inline_everything
2546*da0073e9SAndroid Build Coastguard Worker    def test_call_tracing_mod_from_script_module(self):
2547*da0073e9SAndroid Build Coastguard Worker        class TracedMod(torch.nn.Module):
2548*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2549*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2550*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 5))
2551*da0073e9SAndroid Build Coastguard Worker
2552*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2553*da0073e9SAndroid Build Coastguard Worker                return torch.mm(x, self.param)
2554*da0073e9SAndroid Build Coastguard Worker
2555*da0073e9SAndroid Build Coastguard Worker        class ScriptMod(torch.jit.ScriptModule):
2556*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2557*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2558*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(4, 3))
2559*da0073e9SAndroid Build Coastguard Worker                self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2562*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2563*da0073e9SAndroid Build Coastguard Worker                return self.tm(torch.mm(x, self.param))
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker        sm = ScriptMod()
2566*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph))
2567*da0073e9SAndroid Build Coastguard Worker
2568*da0073e9SAndroid Build Coastguard Worker    def test_script_inline_trace_multiple_args(self):
2569*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
2570*da0073e9SAndroid Build Coastguard Worker            def forward(self, input, input2):
2571*da0073e9SAndroid Build Coastguard Worker                return input + input2
2572*da0073e9SAndroid Build Coastguard Worker
2573*da0073e9SAndroid Build Coastguard Worker        class M2(torch.jit.ScriptModule):
2574*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2575*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2576*da0073e9SAndroid Build Coastguard Worker                self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))
2577*da0073e9SAndroid Build Coastguard Worker
2578*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
2579*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
2580*da0073e9SAndroid Build Coastguard Worker                return self.m(inp, inp)
2581*da0073e9SAndroid Build Coastguard Worker
2582*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
2583*da0073e9SAndroid Build Coastguard Worker            m2 = M2()
2584*da0073e9SAndroid Build Coastguard Worker            m2(torch.zeros(4, 3))
2585*da0073e9SAndroid Build Coastguard Worker
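    # Dict[str, List[Tensor]] example inputs can be traced through a module
    # that delegates to a scripted submodule with matching annotations; the
    # traced module should then generalize to a dict holding different
    # tensor values.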
2586*da0073e9SAndroid Build Coastguard Worker    def test_trace_dict_mix_script(self):
2587*da0073e9SAndroid Build Coastguard Worker        class testB(torch.nn.Module):
2588*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2589*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2590*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(2, 2)
2591*da0073e9SAndroid Build Coastguard Worker
2592*da0073e9SAndroid Build Coastguard Worker            def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor:
2593*da0073e9SAndroid Build Coastguard Worker                output = []
2594*da0073e9SAndroid Build Coastguard Worker                for j in feature_map.values():
2595*da0073e9SAndroid Build Coastguard Worker                    output.append(self.linear(j[0]))
2596*da0073e9SAndroid Build Coastguard Worker
2597*da0073e9SAndroid Build Coastguard Worker                return torch.stack(output)
2598*da0073e9SAndroid Build Coastguard Worker
2599*da0073e9SAndroid Build Coastguard Worker        class testA(torch.nn.Module):
2600*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2601*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2602*da0073e9SAndroid Build Coastguard Worker                self.b = torch.jit.script(testB())
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker            def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor:
2605*da0073e9SAndroid Build Coastguard Worker                feature_map = {}
2606*da0073e9SAndroid Build Coastguard Worker                for i, j in input_map.items():
2607*da0073e9SAndroid Build Coastguard Worker                    feature_map[i] = [j[0]]
2608*da0073e9SAndroid Build Coastguard Worker
2609*da0073e9SAndroid Build Coastguard Worker                return self.b(feature_map)
2610*da0073e9SAndroid Build Coastguard Worker
2611*da0073e9SAndroid Build Coastguard Worker        input_map = {
2612*da0073e9SAndroid Build Coastguard Worker            "1": [torch.rand(2, 2), torch.rand(2, 2)],
2613*da0073e9SAndroid Build Coastguard Worker            "3": [torch.rand(2, 2), torch.rand(2, 2)],
2614*da0073e9SAndroid Build Coastguard Worker        }
2615*da0073e9SAndroid Build Coastguard Worker        model = testA()
2616*da0073e9SAndroid Build Coastguard Worker        traced_model = torch.jit.trace(model, input_map)
2617*da0073e9SAndroid Build Coastguard Worker        new_input_map = {
2618*da0073e9SAndroid Build Coastguard Worker            "1": [torch.rand(2, 2), torch.randn(2, 2)],
2619*da0073e9SAndroid Build Coastguard Worker            "3": [torch.rand(2, 2), torch.rand(2, 2)],
2620*da0073e9SAndroid Build Coastguard Worker        }
2621*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(model(new_input_map), traced_model(new_input_map))
2622*da0073e9SAndroid Build Coastguard Worker
2623*da0073e9SAndroid Build Coastguard Worker    def test_trace_script_returning_complex_dict(self):
2624*da0073e9SAndroid Build Coastguard Worker        """Tracing over a script function returning a dictionary should work.
2625*da0073e9SAndroid Build Coastguard Worker        The dictionary should be able to contain other containers (like a tuple) recursively.
2626*da0073e9SAndroid Build Coastguard Worker        """
2627*da0073e9SAndroid Build Coastguard Worker
2628*da0073e9SAndroid Build Coastguard Worker        class ReturnsDict(torch.nn.Module):
2629*da0073e9SAndroid Build Coastguard Worker            def forward(
2630*da0073e9SAndroid Build Coastguard Worker                self,
2631*da0073e9SAndroid Build Coastguard Worker                id_score_list: Dict[
2632*da0073e9SAndroid Build Coastguard Worker                    str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
2633*da0073e9SAndroid Build Coastguard Worker                ],
2634*da0073e9SAndroid Build Coastguard Worker            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
2635*da0073e9SAndroid Build Coastguard Worker                # do some random operations and then return a dict of the same structure
2636*da0073e9SAndroid Build Coastguard Worker                v = id_score_list["1000"]
2637*da0073e9SAndroid Build Coastguard Worker                idx_keys = v[1] - 1500000
2638*da0073e9SAndroid Build Coastguard Worker                weights = v[2]
2639*da0073e9SAndroid Build Coastguard Worker                result = {"1000": (v[0], idx_keys, weights)}
2640*da0073e9SAndroid Build Coastguard Worker                return result
2641*da0073e9SAndroid Build Coastguard Worker
2642*da0073e9SAndroid Build Coastguard Worker        class ChecksDict(torch.nn.Module):
2643*da0073e9SAndroid Build Coastguard Worker            def forward(
2644*da0073e9SAndroid Build Coastguard Worker                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
2645*da0073e9SAndroid Build Coastguard Worker            ):
2646*da0073e9SAndroid Build Coastguard Worker                v = input["1000"]
2647*da0073e9SAndroid Build Coastguard Worker                return v[1] + 1
2648*da0073e9SAndroid Build Coastguard Worker
2649*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
2650*da0073e9SAndroid Build Coastguard Worker            def __init__(self, checks_dict, returns_dict):
2651*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2652*da0073e9SAndroid Build Coastguard Worker                self.checks_dict = checks_dict
2653*da0073e9SAndroid Build Coastguard Worker                self.returns_dict = returns_dict
2654*da0073e9SAndroid Build Coastguard Worker
2655*da0073e9SAndroid Build Coastguard Worker            def forward(
2656*da0073e9SAndroid Build Coastguard Worker                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
2657*da0073e9SAndroid Build Coastguard Worker            ):
2658*da0073e9SAndroid Build Coastguard Worker                foo = self.returns_dict(input)
2659*da0073e9SAndroid Build Coastguard Worker                return self.checks_dict(foo)
2660*da0073e9SAndroid Build Coastguard Worker
2661*da0073e9SAndroid Build Coastguard Worker        input1 = {
2662*da0073e9SAndroid Build Coastguard Worker            "1000": (
2663*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
2664*da0073e9SAndroid Build Coastguard Worker                torch.tensor([], dtype=torch.int64),
2665*da0073e9SAndroid Build Coastguard Worker                torch.tensor([]),
2666*da0073e9SAndroid Build Coastguard Worker            )
2667*da0073e9SAndroid Build Coastguard Worker        }
2668*da0073e9SAndroid Build Coastguard Worker
2669*da0073e9SAndroid Build Coastguard Worker        input2 = {
2670*da0073e9SAndroid Build Coastguard Worker            "1000": (
2671*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0]),
2672*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1500000, 1500004], dtype=torch.int64),
2673*da0073e9SAndroid Build Coastguard Worker                torch.tensor([2.0, 3.0]),
2674*da0073e9SAndroid Build Coastguard Worker            )
2675*da0073e9SAndroid Build Coastguard Worker        }
2676*da0073e9SAndroid Build Coastguard Worker
2677*da0073e9SAndroid Build Coastguard Worker        checks_dict = torch.jit.script(ChecksDict())
2678*da0073e9SAndroid Build Coastguard Worker        returns_dict = torch.jit.script(ReturnsDict())
2679*da0073e9SAndroid Build Coastguard Worker        eager_module = TestModule(checks_dict, returns_dict)
2680*da0073e9SAndroid Build Coastguard Worker        traced_module = torch.jit.trace(eager_module, input1)
2681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_module(input1), eager_module(input1))
2682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_module(input2), eager_module(input2))
2683*da0073e9SAndroid Build Coastguard Worker
2684*da0073e9SAndroid Build Coastguard Worker    def test_trace_returning_dict_with_tensor_tuples(self):
2685*da0073e9SAndroid Build Coastguard Worker        """Tracing over a module returning a dictionary whose values are tuples of tensors
2686*da0073e9SAndroid Build Coastguard Worker        should work.
2687*da0073e9SAndroid Build Coastguard Worker        """
2688*da0073e9SAndroid Build Coastguard Worker
2689*da0073e9SAndroid Build Coastguard Worker        class ReturnsDict(torch.nn.Module):
2690*da0073e9SAndroid Build Coastguard Worker            def forward(
2691*da0073e9SAndroid Build Coastguard Worker                self, k: torch.Tensor, v: torch.Tensor
2692*da0073e9SAndroid Build Coastguard Worker            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
2693*da0073e9SAndroid Build Coastguard Worker                x = 2 * k
2694*da0073e9SAndroid Build Coastguard Worker                y = 3 * v
2695*da0073e9SAndroid Build Coastguard Worker                result = {"imakey": (x, y)}
2696*da0073e9SAndroid Build Coastguard Worker                return result
2697*da0073e9SAndroid Build Coastguard Worker
2698*da0073e9SAndroid Build Coastguard Worker        class ReturnsBadDict(torch.nn.Module):
2699*da0073e9SAndroid Build Coastguard Worker            def forward(
2700*da0073e9SAndroid Build Coastguard Worker                self, k: torch.Tensor, v: torch.Tensor
2701*da0073e9SAndroid Build Coastguard Worker            ) -> Dict[str, Tuple[torch.Tensor, float]]:
2702*da0073e9SAndroid Build Coastguard Worker                x = 2 * k
2703*da0073e9SAndroid Build Coastguard Worker                result = {"imakey": (x, 1)}
2704*da0073e9SAndroid Build Coastguard Worker                return result
2705*da0073e9SAndroid Build Coastguard Worker
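        # strict=False lets the tracer accept mutable container outputs such
        # as this dict of tensor tuples; ReturnsBadDict is still rejected
        # because one element of its output tuple is a plain number rather
        # than a tensor.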
2706*da0073e9SAndroid Build Coastguard Worker        mod = ReturnsDict()
2707*da0073e9SAndroid Build Coastguard Worker        traced_module = torch.jit.trace(
2708*da0073e9SAndroid Build Coastguard Worker            mod, [torch.ones(1), torch.ones(1)], strict=False
2709*da0073e9SAndroid Build Coastguard Worker        )
2710*da0073e9SAndroid Build Coastguard Worker        out = traced_module(torch.ones(1), torch.ones(1))
2711*da0073e9SAndroid Build Coastguard Worker        expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))}
2712*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, expected)
2713*da0073e9SAndroid Build Coastguard Worker
2714*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
2715*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "cannot be understood by the tracer, only outputs matching"
2716*da0073e9SAndroid Build Coastguard Worker        ):
2717*da0073e9SAndroid Build Coastguard Worker            mod = ReturnsBadDict()
2718*da0073e9SAndroid Build Coastguard Worker            traced_module = torch.jit.trace(
2719*da0073e9SAndroid Build Coastguard Worker                mod, [torch.ones(1), torch.ones(1)], strict=False
2720*da0073e9SAndroid Build Coastguard Worker            )
2721*da0073e9SAndroid Build Coastguard Worker
2722*da0073e9SAndroid Build Coastguard Worker    def test_trace_linear(self):
2723*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Linear(20, 20)
2724*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand([20, 20])
2725*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(m, (inp,))
2726*da0073e9SAndroid Build Coastguard Worker        g = torch.jit.trace(m, (inp,)).graph
2727*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::linear").run(g)
2728*da0073e9SAndroid Build Coastguard Worker
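    # A scripted module whose forward signature matches a @torch.jit.interface
    # can be passed where that interface type is expected; checkScript below
    # verifies the call through the interface-typed argument.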
2729*da0073e9SAndroid Build Coastguard Worker    def test_traced_module_implements_interface(self):
2730*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
2731*da0073e9SAndroid Build Coastguard Worker        class TestModuleInterface(nn.Module):
2732*da0073e9SAndroid Build Coastguard Worker            def forward(
2733*da0073e9SAndroid Build Coastguard Worker                self, first_arg: torch.Tensor, second_arg: torch.Tensor
2734*da0073e9SAndroid Build Coastguard Worker            ) -> torch.Tensor:
2735*da0073e9SAndroid Build Coastguard Worker                pass
2736*da0073e9SAndroid Build Coastguard Worker
2737*da0073e9SAndroid Build Coastguard Worker        make_global(TestModuleInterface)
2738*da0073e9SAndroid Build Coastguard Worker
2739*da0073e9SAndroid Build Coastguard Worker        class TestModule(nn.Module):
2740*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2741*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2742*da0073e9SAndroid Build Coastguard Worker                self.conv = nn.Conv2d(1, 1, 3)
2743*da0073e9SAndroid Build Coastguard Worker
2744*da0073e9SAndroid Build Coastguard Worker            def forward(
2745*da0073e9SAndroid Build Coastguard Worker                self, first_arg: torch.Tensor, second_arg: torch.Tensor
2746*da0073e9SAndroid Build Coastguard Worker            ) -> torch.Tensor:
2747*da0073e9SAndroid Build Coastguard Worker                return self.conv(first_arg) + second_arg
2748*da0073e9SAndroid Build Coastguard Worker
2749*da0073e9SAndroid Build Coastguard Worker        def fn_takes_interface(x: TestModuleInterface):
2750*da0073e9SAndroid Build Coastguard Worker            ones = torch.ones(1, 1, 3, 3)
2751*da0073e9SAndroid Build Coastguard Worker            return x.forward(ones, ones)
2752*da0073e9SAndroid Build Coastguard Worker
2753*da0073e9SAndroid Build Coastguard Worker        scripted_test_module = torch.jit.script(TestModule())
2754*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn_takes_interface, (scripted_test_module,))
2755*da0073e9SAndroid Build Coastguard Worker
2756*da0073e9SAndroid Build Coastguard Worker    def test_traced_module_contains_scripted_interface_types(self):
2757*da0073e9SAndroid Build Coastguard Worker        class LeafModule(torch.nn.Module):
2758*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2759*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2760*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(19))
2761*da0073e9SAndroid Build Coastguard Worker
2762*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: torch.Tensor):
2763*da0073e9SAndroid Build Coastguard Worker                return input + self.weight
2764*da0073e9SAndroid Build Coastguard Worker
2765*da0073e9SAndroid Build Coastguard Worker        class LowerModuleImpl(torch.nn.Module):
2766*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2767*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2768*da0073e9SAndroid Build Coastguard Worker                self.leaf = LeafModule()
2769*da0073e9SAndroid Build Coastguard Worker
2770*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: torch.Tensor) -> torch.Tensor:
2771*da0073e9SAndroid Build Coastguard Worker                return self.leaf(input)
2772*da0073e9SAndroid Build Coastguard Worker
2773*da0073e9SAndroid Build Coastguard Worker        @torch.jit.interface
2774*da0073e9SAndroid Build Coastguard Worker        class LowerModuleInterface(torch.nn.Module):
2775*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: torch.Tensor) -> torch.Tensor:
2776*da0073e9SAndroid Build Coastguard Worker                pass
2777*da0073e9SAndroid Build Coastguard Worker
2778*da0073e9SAndroid Build Coastguard Worker        class MiddleModule(torch.nn.Module):
2779*da0073e9SAndroid Build Coastguard Worker            lower: LowerModuleInterface
2780*da0073e9SAndroid Build Coastguard Worker
2781*da0073e9SAndroid Build Coastguard Worker            def __init__(self, feature_processor_modules=None):
2782*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2783*da0073e9SAndroid Build Coastguard Worker                self.lower = LowerModuleImpl()
2784*da0073e9SAndroid Build Coastguard Worker
2785*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2786*da0073e9SAndroid Build Coastguard Worker                return self.lower(input)
2787*da0073e9SAndroid Build Coastguard Worker
2788*da0073e9SAndroid Build Coastguard Worker        class WrapperModule(torch.nn.Module):
2789*da0073e9SAndroid Build Coastguard Worker            def __init__(self, m):
2790*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2791*da0073e9SAndroid Build Coastguard Worker                self.middle = m
2792*da0073e9SAndroid Build Coastguard Worker
2793*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
2794*da0073e9SAndroid Build Coastguard Worker                return self.middle(input)
2795*da0073e9SAndroid Build Coastguard Worker
2796*da0073e9SAndroid Build Coastguard Worker        class TopModule(torch.nn.Module):
2797*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
2798*da0073e9SAndroid Build Coastguard Worker                super().__init__()
2799*da0073e9SAndroid Build Coastguard Worker                m = MiddleModule()
2800*da0073e9SAndroid Build Coastguard Worker                m = torch.jit.script(m)
2801*da0073e9SAndroid Build Coastguard Worker                self.sub1 = m
2802*da0073e9SAndroid Build Coastguard Worker                self.sub2 = WrapperModule(m)
2803*da0073e9SAndroid Build Coastguard Worker
2804*da0073e9SAndroid Build Coastguard Worker            def forward(self, input: torch.Tensor):
2805*da0073e9SAndroid Build Coastguard Worker                return self.sub1(input) + self.sub2(input)
2806*da0073e9SAndroid Build Coastguard Worker
2807*da0073e9SAndroid Build Coastguard Worker        top = TopModule()
2808*da0073e9SAndroid Build Coastguard Worker        top_example_input = torch.ones(1)
2809*da0073e9SAndroid Build Coastguard Worker        torch.jit.trace(top, top_example_input)
2810*da0073e9SAndroid Build Coastguard Worker
2811*da0073e9SAndroid Build Coastguard Worker    def test_jit_trace_callfunction_return_shapes(self):
2812*da0073e9SAndroid Build Coastguard Worker        # a torch.jit.script function gets inserted as a CallFunction node
2813*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
2814*da0073e9SAndroid Build Coastguard Worker        def inner_fn(x):
2815*da0073e9SAndroid Build Coastguard Worker            return torch.cat((x, x))
2816*da0073e9SAndroid Build Coastguard Worker
2817*da0073e9SAndroid Build Coastguard Worker        def outer_fn(x, y):
2818*da0073e9SAndroid Build Coastguard Worker            return inner_fn(x + y).relu()
2819*da0073e9SAndroid Build Coastguard Worker
2820*da0073e9SAndroid Build Coastguard Worker        x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)]
2821*da0073e9SAndroid Build Coastguard Worker        fn_t = torch.jit.trace(outer_fn, (x, y))
2822*da0073e9SAndroid Build Coastguard Worker
2823*da0073e9SAndroid Build Coastguard Worker        # expect that the CallFunction node return type has shape information on it.
2824*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph)
2825*da0073e9SAndroid Build Coastguard Worker        for n in fn_t.graph.nodes():
2826*da0073e9SAndroid Build Coastguard Worker            if n.kind() == "prim::CallFunction":
2827*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(n.output().isCompleteTensor())
2828