/* * CuDNN ReLU extension. Simple function but contains the general structure of * most CuDNN extensions: * 1) Check arguments. torch::check* functions provide a standard way to * validate input and provide pretty errors. 2) Create descriptors. Most CuDNN * functions require creating and setting a variety of descriptors. 3) Apply the * CuDNN function. 4) Destroy your descriptors. 5) Return something (optional). */ #include #include // for CUDNN_CHECK #include // for TensorDescriptor #include // for getCudnnHandle // Name of function in python module and name used for error messages by // torch::check* functions. const char* cudnn_relu_name = "cudnn_relu"; // Check arguments to cudnn_relu void cudnn_relu_check( const torch::Tensor& inputs, const torch::Tensor& outputs) { // Create TensorArgs. These record the names and positions of each tensor as a // parameter. torch::TensorArg arg_inputs(inputs, "inputs", 0); torch::TensorArg arg_outputs(outputs, "outputs", 1); // Check arguments. No need to return anything. These functions with throw an // error if they fail. Messages are populated using information from // TensorArgs. torch::checkContiguous(cudnn_relu_name, arg_inputs); torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat); torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA); torch::checkContiguous(cudnn_relu_name, arg_outputs); torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat); torch::checkBackend( cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA); torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs); } void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) { // Most CuDNN extensions will follow a similar pattern. // Step 1: Check inputs. This will throw an error if inputs are invalid, so no // need to check return codes here. cudnn_relu_check(inputs, outputs); // Step 2: Create descriptors cudnnHandle_t cuDnn = torch::native::getCudnnHandle(); // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same // size and type and contiguous, so one descriptor is sufficient. torch::native::TensorDescriptor input_tensor_desc(inputs, 4); cudnnActivationDescriptor_t activationDesc; // Note: Always check return value of cudnn functions using CUDNN_CHECK AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc)); AT_CUDNN_CHECK(cudnnSetActivationDescriptor( activationDesc, /*mode=*/CUDNN_ACTIVATION_RELU, /*reluNanOpt=*/CUDNN_PROPAGATE_NAN, /*coef=*/1.)); // Step 3: Apply CuDNN function float alpha = 1.; float beta = 0.; AT_CUDNN_CHECK(cudnnActivationForward( cuDnn, activationDesc, &alpha, input_tensor_desc.desc(), inputs.data_ptr(), &beta, input_tensor_desc.desc(), // output descriptor same as input outputs.data_ptr())); // Step 4: Destroy descriptors AT_CUDNN_CHECK(cudnnDestroyActivationDescriptor(activationDesc)); // Step 5: Return something (optional) } // Create the pybind11 module PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Use the same name as the check functions so error messages make sense m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU"); }