xref: /aosp_15_r20/external/pytorch/test/test_cpp_api_parity.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: cpp"]
2
3
4import os
5
6from cpp_api_parity import (
7    functional_impl_check,
8    module_impl_check,
9    sample_functional,
10    sample_module,
11)
12from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
13from cpp_api_parity.utils import is_torch_nn_functional_test
14
15import torch
16import torch.testing._internal.common_nn as common_nn
17import torch.testing._internal.common_utils as common
18
19
# NOTE: flip this to True to dump the generated C++ source of every test
# (handy when debugging the C++ side of the parity checks).
PRINT_CPP_SOURCE = False

# Each generated parity test is instantiated once for every device listed here.
devices = ["cpu", "cuda"]

# Markdown parity-tracker file that lives in the `cpp_api_parity` directory
# next to this test file; it is parsed once at import time.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__),
    "cpp_api_parity",
    "parity-tracker.md",
)
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)
30
31
32@torch.testing._internal.common_utils.markDynamoStrictTest
33class TestCppApiParity(common.TestCase):
34    module_test_params_map = {}
35    functional_test_params_map = {}
36
37
expected_test_params_dicts = []


def _assert_generated_test_count(name_fragment, expected):
    """Check how many ``TestCppApiParity`` attributes contain ``name_fragment``.

    Only the class's own ``__dict__`` is inspected, i.e. the tests attached by
    the generators in this module. An explicit ``raise`` is used instead of an
    ``assert`` statement so the sanity check still runs under ``python -O``;
    the exception type (``AssertionError``) is unchanged.

    Args:
        name_fragment: substring that identifies the generated tests to count.
        expected: number of matching attributes that must exist.

    Raises:
        AssertionError: if the number of matching attributes differs from
            ``expected``. The message reports both counts.
    """
    actual = sum(1 for name in TestCppApiParity.__dict__ if name_fragment in name)
    if actual != expected:
        raise AssertionError(
            f"Expected {expected} generated tests containing {name_fragment!r} "
            f"in TestCppApiParity, but found {actual}"
        )


if not common.IS_ARM64:
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # A test dict opts out of the parity check with
            # `test_cpp_api_parity=False`; absent means "included".
            if not test_params_dict.get("test_cpp_api_parity", True):
                continue
            # Functional and module checkers expose the same
            # `write_test_to_test_class` interface, so pick one and call it once.
            impl_check = (
                functional_impl_check
                if is_torch_nn_functional_test(test_params_dict)
                else module_impl_check
            )
            impl_check.write_test_to_test_class(
                TestCppApiParity,
                test_params_dict,
                test_instance_class,
                parity_table,
                devices,
            )
            expected_test_params_dicts.append(test_params_dict)

    # Every included NN module/functional test dict must yield one test per
    # device.
    _assert_generated_test_count(
        "test_torch_nn_", len(expected_test_params_dicts) * len(devices)
    )

    # There must be auto-generated tests for `SampleModule` and
    # `sample_functional`: each contributes 2 test dicts that are not skipped,
    # generated once per device (derived from `devices` rather than
    # hard-coding 4).
    _assert_generated_test_count("SampleModule", 2 * len(devices))
    _assert_generated_test_count("sample_functional", 2 * len(devices))

    module_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )
    functional_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )
90
def _main():
    """Run this module's tests when the file is executed as a script."""
    # The flag must be set before handing control to the test runner.
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()


if __name__ == "__main__":
    _main()
93    common.run_tests()
94