1import torch 2 3 4""" 5`SampleModule` is used by `test_cpp_api_parity.py` to test that Python / C++ API 6parity test harness works for `torch.nn.Module` subclasses. 7 8When `SampleModule.has_parity` is true, behavior of `forward` / `backward` 9is the same as the C++ equivalent. 10 11When `SampleModule.has_parity` is false, behavior of `forward` / `backward` 12is different from the C++ equivalent. 13""" 14 15 16class SampleModule(torch.nn.Module): 17 def __init__(self, has_parity, has_submodule): 18 super().__init__() 19 self.has_parity = has_parity 20 if has_submodule: 21 self.submodule = SampleModule(self.has_parity, False) 22 23 self.has_submodule = has_submodule 24 self.register_parameter("param", torch.nn.Parameter(torch.empty(3, 4))) 25 26 self.reset_parameters() 27 28 def reset_parameters(self): 29 with torch.no_grad(): 30 self.param.fill_(1) 31 32 def forward(self, x): 33 submodule_forward_result = ( 34 self.submodule(x) if hasattr(self, "submodule") else 0 35 ) 36 if self.has_parity: 37 return x + self.param * 2 + submodule_forward_result 38 else: 39 return x + self.param * 4 + submodule_forward_result + 3 40 41 42torch.nn.SampleModule = SampleModule 43 44SAMPLE_MODULE_CPP_SOURCE = """\n 45namespace torch { 46namespace nn { 47struct C10_EXPORT SampleModuleOptions { 48 SampleModuleOptions(bool has_parity, bool has_submodule) : has_parity_(has_parity), has_submodule_(has_submodule) {} 49 50 TORCH_ARG(bool, has_parity); 51 TORCH_ARG(bool, has_submodule); 52}; 53 54struct C10_EXPORT SampleModuleImpl : public torch::nn::Cloneable<SampleModuleImpl> { 55 explicit SampleModuleImpl(SampleModuleOptions options) : options(std::move(options)) { 56 if (options.has_submodule()) { 57 submodule = register_module( 58 "submodule", 59 std::make_shared<SampleModuleImpl>(SampleModuleOptions(options.has_parity(), false))); 60 } 61 reset(); 62 } 63 void reset() { 64 param = register_parameter("param", torch::ones({3, 4})); 65 } 66 torch::Tensor forward(torch::Tensor x) { 67 return x + param * 2 + (submodule ? submodule->forward(x) : torch::zeros_like(x)); 68 } 69 SampleModuleOptions options; 70 torch::Tensor param; 71 std::shared_ptr<SampleModuleImpl> submodule{nullptr}; 72}; 73 74TORCH_MODULE(SampleModule); 75} // namespace nn 76} // namespace torch 77""" 78 79module_tests = [ 80 dict( 81 module_name="SampleModule", 82 desc="has_parity", 83 constructor_args=(True, True), 84 cpp_constructor_args="torch::nn::SampleModuleOptions(true, true)", 85 input_size=(3, 4), 86 cpp_input_args=["torch::randn({3, 4})"], 87 has_parity=True, 88 ), 89 dict( 90 fullname="SampleModule_no_parity", 91 constructor=lambda: SampleModule(has_parity=False, has_submodule=True), 92 cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)", 93 input_size=(3, 4), 94 cpp_input_args=["torch::randn({3, 4})"], 95 has_parity=False, 96 ), 97 # This is to test that setting the `test_cpp_api_parity=False` flag skips 98 # the C++ API parity test accordingly (otherwise this test would run and 99 # throw a parity error). 100 dict( 101 fullname="SampleModule_THIS_TEST_SHOULD_BE_SKIPPED", 102 constructor=lambda: SampleModule(False, True), 103 cpp_constructor_args="torch::nn::SampleModuleOptions(false, true)", 104 input_size=(3, 4), 105 cpp_input_args=["torch::randn({3, 4})"], 106 test_cpp_api_parity=False, 107 ), 108] 109