# xref: /aosp_15_r20/external/pytorch/test/jit/test_python_ir.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import unittest
4
5import numpy as np
6
7import torch
8from torch.testing import FileCheck
9from torch.testing._internal.common_utils import IS_MACOS
10from torch.testing._internal.jit_utils import JitTestCase
11
12
# Guard against direct execution: this module's tests are meant to be run
# through the test/test_jit.py driver, as the error message below explains.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
19
20
class TestPythonIr(JitTestCase):
    """Exercise the Python-side TorchScript IR API (Graph, Node and Value methods)."""

    def test_param_strides(self):
        """The type recorded for a traced graph input carries the input tensor's strides."""

        def trace_me(arg):
            return arg

        t = torch.zeros(1, 3, 16, 16)
        traced = torch.jit.trace(trace_me, t)
        # First output of the param node is the graph input Value traced from `t`.
        value = list(traced.graph.param_node().outputs())[0]
        real_strides = list(t.stride())
        type_strides = value.type().strides()
        self.assertEqual(real_strides, type_strides)

    def test_permute_inputs_binding(self):
        """Graph.permuteInputs reorders graph inputs by the given index list."""

        @torch.jit.script
        def foo(i, j, k):
            pass

        g = foo.graph

        idxs = []
        for i, inp in enumerate(g.inputs()):
            # Tag each input with its original position so the new order is observable.
            inp.setDebugName(f"inp{i}")
            idxs.append(i)

        # Use a fixed, non-identity permutation. A random permutation is the
        # identity 1 time in 3! and would then verify nothing; rolling by one
        # moves every element whenever there is more than one input. .tolist()
        # hands plain Python ints (not numpy scalars) to the binding.
        permuted_idxs = np.roll(idxs, 1).tolist()
        g.permuteInputs(permuted_idxs)
        for i, inp in enumerate(g.inputs()):
            self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())

    @unittest.skipIf(IS_MACOS, "Failing on MacOS only")
    def test_python_ir_utils(self):
        """insert_point_guard scopes constant insertion; Node.matches compares schemas."""

        @torch.jit.script
        def foo(inp):
            x = inp + 1
            y = x / 2
            z = y * y
            return z

        add_node = foo.graph.findNode("aten::add")
        div_node = foo.graph.findNode("aten::div")

        # Guards nest: "goodbye" goes in at the div node, then once the inner
        # guard exits the insert point is back at the add node for "hello".
        with foo.graph.insert_point_guard(add_node):
            with foo.graph.insert_point_guard(div_node):
                foo.graph.insertConstant("goodbye")
            foo.graph.insertConstant("hello")
        with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
            foo.graph.insertConstant("hello")
        # Expected textual order: constant at add, then at div, then at mul.
        FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)

        self.assertTrue(add_node.matches(add_node.schema()))
        self.assertFalse(add_node.matches(div_node.schema()))

    def test_python_ir_utils_graph(self):
        """Graph.insertGraph can inline one scripted graph in place of a node of another."""

        @torch.jit.script
        def unrolled_mul(x: torch.Tensor, y: int):
            out = x
            for _ in range(y - 1):
                out = out + x
            return out

        @torch.jit.script
        def foo(x):
            return x * 4

        g = foo.graph
        # Collect every `Tensor * <int constant>` multiply up front; the loop
        # below destroys nodes, so don't iterate a lazy filter over the graph.
        mul_constant_int = [
            mul
            for mul in g.findAllNodes("aten::mul")
            if mul.matches("aten::mul(Tensor self, Scalar other) -> Tensor")
            and isinstance(list(mul.inputs())[1].toIValue(), int)
        ]
        for mul in mul_constant_int:
            with g.insert_point_guard(mul):
                # Inline unrolled_mul's body, feeding it mul's (tensor, int) inputs.
                outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
                old_outputs = list(mul.outputs())
                self.assertEqual(len(outputs), len(old_outputs))
                # Rewire users of the node being replaced. (The graph's own
                # outputs only coincide with these when mul is what foo returns.)
                for new_out, old_out in zip(outputs, old_outputs):
                    old_out.replaceAllUsesWith(new_out)
                mul.destroy()

        FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
        self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
103