# Owner(s): ["module: cpp"]

"""C++ API parity tests for ``torch.nn`` modules and functionals.

For every test-params dict in the sample tests and in ``common_nn``'s module /
criterion test lists, a parity test is attached to ``TestCppApiParity`` once
per entry in ``devices``. ``build_cpp_tests`` is then called on the class.
Nothing is generated when ``common.IS_ARM64`` is true.
"""

import os

from cpp_api_parity import (
    functional_impl_check,
    module_impl_check,
    sample_functional,
    sample_module,
)
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
from cpp_api_parity.utils import is_torch_nn_functional_test

import torch
import torch.testing._internal.common_nn as common_nn
import torch.testing._internal.common_utils as common


# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose)
PRINT_CPP_SOURCE = False

# Each non-skipped test-params dict yields one generated test per device listed here.
devices = ["cpu", "cuda"]

PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
)

# Parsed parity tracker table; handed to every write_test_to_test_class call below.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)


@common.markDynamoStrictTest
class TestCppApiParity(common.TestCase):
    """Container for the auto-generated C++ API parity tests.

    Test methods are added to this class at import time by
    ``module_impl_check.write_test_to_test_class`` and
    ``functional_impl_check.write_test_to_test_class``.
    """

    # NOTE(review): presumably filled in by the *_impl_check helpers while
    # generating tests (test name -> params) — confirm against cpp_api_parity.
    module_test_params_map = {}
    functional_test_params_map = {}


# Test-params dicts for which parity tests were actually generated.
expected_test_params_dicts = []

# NOTE(review): the reason for skipping generation on ARM64 is not visible in
# this file — confirm before changing.
if not common.IS_ARM64:
    for test_params_dicts, test_instance_class in [
        (sample_module.module_tests, common_nn.NewModuleTest),
        (sample_functional.functional_tests, common_nn.NewModuleTest),
        (common_nn.module_tests, common_nn.NewModuleTest),
        (common_nn.new_module_tests, common_nn.NewModuleTest),
        (common_nn.criterion_tests, common_nn.CriterionTest),
    ]:
        for test_params_dict in test_params_dicts:
            # Individual dicts may opt out with ``test_cpp_api_parity=False``.
            if not test_params_dict.get("test_cpp_api_parity", True):
                continue
            # Functional tests and module tests are written by different checkers
            # that share the same write_test_to_test_class signature.
            impl_check = (
                functional_impl_check
                if is_torch_nn_functional_test(test_params_dict)
                else module_impl_check
            )
            impl_check.write_test_to_test_class(
                TestCppApiParity,
                test_params_dict,
                test_instance_class,
                parity_table,
                devices,
            )
            expected_test_params_dicts.append(test_params_dict)

    # Assert that all NN module/functional test dicts appear in the parity test,
    # once per device.
    num_nn_tests = len(
        [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name]
    )
    expected_num_nn_tests = len(expected_test_params_dicts) * len(devices)
    assert num_nn_tests == expected_num_nn_tests, (
        f"Expected {expected_num_nn_tests} generated `test_torch_nn_*` tests, "
        f"found {num_nn_tests}"
    )

    # Assert that there exist auto-generated tests for `SampleModule` and
    # `sample_functional`: 2 (number of test dicts that are not skipped) times
    # the number of devices. Derived from `devices` instead of hard-coding 4 so
    # the check stays correct if the device list changes.
    expected_num_sample_tests = 2 * len(devices)
    for sample_name in ("SampleModule", "sample_functional"):
        num_sample_tests = len(
            [name for name in TestCppApiParity.__dict__ if sample_name in name]
        )
        assert num_sample_tests == expected_num_sample_tests, (
            f"Expected {expected_num_sample_tests} generated tests containing "
            f"{sample_name!r}, found {num_sample_tests}"
        )

    # NOTE(review): presumably compiles the C++ sources for the tests written
    # above; PRINT_CPP_SOURCE controls whether those sources are printed.
    module_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )
    functional_impl_check.build_cpp_tests(
        TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
    )


if __name__ == "__main__":
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()