# Owner(s): ["module: onnx"]

import unittest

import onnx_test_common
import onnxruntime  # noqa: F401
import parameterized
from onnx_test_common import MAX_ONNX_OPSET_VERSION, MIN_ONNX_OPSET_VERSION
from pytorch_test_common import (
    skipIfNoBFloat16Cuda,
    skipIfNoCuda,
    skipIfUnsupportedMinOpsetVersion,
    skipScriptTest,
)
from test_pytorch_onnx_onnxruntime import _parameterized_class_attrs_and_values

import torch
from torch.cuda.amp import autocast
from torch.testing._internal import common_utils


@parameterized.parameterized_class(
    **_parameterized_class_attrs_and_values(
        MIN_ONNX_OPSET_VERSION, MAX_ONNX_OPSET_VERSION
    ),
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
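    # parameterized_class generates one copy of this test class per attribute
    # combination produced by _parameterized_class_attrs_and_values, covering
    # ONNX opset versions from MIN_ONNX_OPSET_VERSION through MAX_ONNX_OPSET_VERSION.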
    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    def test_gelu_fp16(self):
        class GeluModel(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.gelu(x)

        x = torch.randn(
            2,
            4,
            5,
            6,
            requires_grad=True,
            dtype=torch.float16,
            device=torch.device("cuda"),
        )
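        # float16 has limited precision, so compare against the ONNX Runtime
        # output with relaxed tolerances.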
        self.run_test(GeluModel(), x, rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfNoCuda
    @skipScriptTest()
    def test_layer_norm_fp16(self):
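        # LayerNorm runs under torch.cuda.amp.autocast with a float16 CUDA input;
        # the module is moved to CUDA before export, and the TorchScript variant
        # of this test is skipped.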
        class LayerNormModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm([10, 10])

            @autocast()
            def forward(self, x):
                return self.layer_norm(x)

        x = torch.randn(
            20,
            5,
            10,
            10,
            requires_grad=True,
            dtype=torch.float16,
            device=torch.device("cuda"),
        )
        self.run_test(LayerNormModel().cuda(), x, rtol=1e-3, atol=1e-5)

    @skipIfUnsupportedMinOpsetVersion(12)
    @skipIfNoCuda
    @skipScriptTest()
    def test_softmaxCrossEntropy_fusion_fp16(self):
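        # LogSoftmax followed by NLLLoss under autocast. ONNX added
        # SoftmaxCrossEntropyLoss/NegativeLogLikelihoodLoss in opset 12, hence the
        # opset >= 12 requirement for exporting this fused pattern.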
        class FusionModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.loss = torch.nn.NLLLoss(reduction="none")
                self.m = torch.nn.LogSoftmax(dim=1)

            @autocast()
            def forward(self, input, target):
                output = self.loss(self.m(2 * input), target)
                return output

        N, C = 5, 4
        input = torch.randn(N, 16, dtype=torch.float16, device=torch.device("cuda"))
        target = torch.empty(N, dtype=torch.long, device=torch.device("cuda")).random_(
            0, C
        )

        # use test data containing the default ignore_index=-100
        target[target == 1] = -100
        self.run_test(FusionModel(), (input, target))

    @skipIfNoCuda
    @skipScriptTest()
    def test_apex_o2(self):
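        # apex.amp opt_level="O2" keeps most of the model's weights in float16;
        # the test is skipped entirely when apex is not installed.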
        class LinearModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        try:
            from apex import amp
        except Exception as e:
            raise unittest.SkipTest("Apex is not available") from e
        input = torch.randn(3, 3, device=torch.device("cuda"))
        model = amp.initialize(LinearModel(), opt_level="O2")
        self.run_test(model, input)

    # ONNX supports bfloat16 for opsets >= 13.
    # The Add, Sub, and Mul ops don't support bfloat16 on CPU in ONNX Runtime.
    @skipIfUnsupportedMinOpsetVersion(13)
    @skipIfNoBFloat16Cuda
    def test_arithmetic_bfp16(self):
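        # The float16 input is cast to bfloat16 via type_as, the arithmetic runs
        # in bfloat16 on CUDA, and the result is cast back to float16.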
        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4, dtype=torch.bfloat16, device=torch.device("cuda"))
                x = x.type_as(y)
                return torch.mul(torch.add(x, y), torch.sub(x, y)).to(
                    dtype=torch.float16
                )

        x = torch.ones(
            3, 4, requires_grad=True, dtype=torch.float16, device=torch.device("cuda")
        )
        self.run_test(MyModule(), x, rtol=1e-3, atol=1e-5)

    @skipIfNoCuda
    def test_deduplicate_initializers_diff_devices(self):
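        # The parameters live on different devices (w on CPU, b on CUDA); export
        # should still handle and deduplicate the initializers correctly.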
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Parameter(
                    torch.ones(2, 3, device=torch.device("cpu"))
                )
                self.b = torch.nn.Parameter(torch.ones(3, device=torch.device("cuda")))

            def forward(self, x, y):
                return torch.matmul(self.w, x), y + self.b

        x = torch.randn(3, 3, device=torch.device("cpu"))
        y = torch.randn(3, 3, device=torch.device("cuda"))
        self.run_test(Model(), (x, y))


if __name__ == "__main__":
    common_utils.run_tests()