xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/identity.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h>
2*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker using namespace torch::autograd;
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker class Identity : public Function<Identity> {
7*da0073e9SAndroid Build Coastguard Worker  public:
forward(AutogradContext * ctx,torch::Tensor input)8*da0073e9SAndroid Build Coastguard Worker   static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) {
9*da0073e9SAndroid Build Coastguard Worker     return input;
10*da0073e9SAndroid Build Coastguard Worker   }
11*da0073e9SAndroid Build Coastguard Worker 
backward(AutogradContext * ctx,tensor_list grad_outputs)12*da0073e9SAndroid Build Coastguard Worker   static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
13*da0073e9SAndroid Build Coastguard Worker     return {grad_outputs[0]};
14*da0073e9SAndroid Build Coastguard Worker   }
15*da0073e9SAndroid Build Coastguard Worker };
16*da0073e9SAndroid Build Coastguard Worker 
identity(torch::Tensor input)17*da0073e9SAndroid Build Coastguard Worker torch::Tensor identity(torch::Tensor input) {
18*da0073e9SAndroid Build Coastguard Worker   return Identity::apply(input);
19*da0073e9SAndroid Build Coastguard Worker }
20*da0073e9SAndroid Build Coastguard Worker 
// Python bindings for this extension module. The module name comes from the
// TORCH_EXTENSION_NAME macro (presumably defined by the extension build —
// confirm). It registers one function, `identity`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(/*name=*/"identity", /*f=*/&identity, /*doc=*/"identity");
}
24