1*da0073e9SAndroid Build Coastguard Worker #include <torch/extension.h> 2*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h> 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Worker using namespace torch::autograd; 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker class Identity : public Function<Identity> { 7*da0073e9SAndroid Build Coastguard Worker public: forward(AutogradContext * ctx,torch::Tensor input)8*da0073e9SAndroid Build Coastguard Worker static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input) { 9*da0073e9SAndroid Build Coastguard Worker return input; 10*da0073e9SAndroid Build Coastguard Worker } 11*da0073e9SAndroid Build Coastguard Worker backward(AutogradContext * ctx,tensor_list grad_outputs)12*da0073e9SAndroid Build Coastguard Worker static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) { 13*da0073e9SAndroid Build Coastguard Worker return {grad_outputs[0]}; 14*da0073e9SAndroid Build Coastguard Worker } 15*da0073e9SAndroid Build Coastguard Worker }; 16*da0073e9SAndroid Build Coastguard Worker identity(torch::Tensor input)17*da0073e9SAndroid Build Coastguard Workertorch::Tensor identity(torch::Tensor input) { 18*da0073e9SAndroid Build Coastguard Worker return Identity::apply(input); 19*da0073e9SAndroid Build Coastguard Worker } 20*da0073e9SAndroid Build Coastguard Worker PYBIND11_MODULE(TORCH_EXTENSION_NAME,m)21*da0073e9SAndroid Build Coastguard WorkerPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22*da0073e9SAndroid Build Coastguard Worker m.def("identity", &identity, "identity"); 23*da0073e9SAndroid Build Coastguard Worker } 24