#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>

#include <torch/csrc/jit/codegen/fuser/compiler.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>

#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <vector>

namespace torch::jit::fuser::cuda {

// See NOTE [ USE OF NVRTC AND DRIVER API ]
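// Convenience accessor for the lazily loaded NVRTC / CUDA driver API wrappers
// owned by the global context.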
const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

// query codegen output arch and target
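// Example: an 8.0 (A100) device with NVRTC 11.0 yields major = 8, minor = 0
// and compile_to_sass = true; a device newer than this NVRTC version supports
// is clamped to the highest supported arch and compile_to_sass is set to
// false, so the caller falls back to PTX.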
void codegenOutputQuery(
    const cudaDeviceProp* const prop,
    int& major,
    int& minor,
    bool& compile_to_sass) {
#ifdef USE_ROCM
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&major, &minor));
  compile_to_sass = false;
#else
  using CudaVersion = std::pair<int, int>;
  CudaVersion nvrtc_version;
  AT_CUDA_NVRTC_CHECK(
      nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));

  TORCH_CHECK(
      nvrtc_version.first >= 6,
      "NVRTC versions less than 6 are not supported. Is: ",
      nvrtc_version.first);

  // Version supported by device
  // Usually any lower version works too but is less efficient
  const CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
  // Maximum compute capability supported by this NVRTC version; cap
  // dev_version to this
  CudaVersion max_dev_version;
  if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
    max_dev_version = CudaVersion(5, 0);
  } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
    max_dev_version = CudaVersion(6, 0);
  } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
    max_dev_version = CudaVersion(7, 2);
  } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
    max_dev_version = CudaVersion(7, 5);
  } else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0
    max_dev_version = CudaVersion(8, 0);
  } else if (nvrtc_version.first == 11 && nvrtc_version.second < 8) {
    // 11.1 through 11.7 support up to 8.6
    max_dev_version = CudaVersion(8, 6);
  } else {
    // If the NVRTC version is newer than this code knows about, assume it
    // supports this device
    max_dev_version = dev_version;
  }
  if (dev_version > max_dev_version) {
    major = max_dev_version.first;
    minor = max_dev_version.second;
    // if we are clamping major/minor, sass is not compatible
    compile_to_sass = false;
  } else {
    major = dev_version.first;
    minor = dev_version.second;
    compile_to_sass = true;
  }
#endif
}

// Compiles the specified kernel and stores the metadata required to run it
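// The flow is: query the target arch, compile the generated CUDA source to
// PTX (or SASS) with NVRTC, load the result through the driver API, and cache
// the kernel handle together with an occupancy-based cap on the launch grid.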
FusedKernelCUDA::FusedKernelCUDA(
    at::DeviceIndex device,
    std::string name,
    std::string code,
    std::vector<TensorDesc> input_desc,
    std::vector<TensorDesc> output_desc,
    std::vector<PartitionDesc> chunk_desc,
    std::vector<PartitionDesc> concat_desc,
    bool has_random)
    : FusedKernel(
          std::move(name),
          std::move(code),
          std::move(input_desc),
          std::move(output_desc),
          std::move(chunk_desc),
          std::move(concat_desc),
          has_random),
      device_(device) {
  // Initializes driver's API context (if necessary)
  at::cuda::jit::initializeCudaContext();

  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios
  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  // Acquires device and NVRTC properties (for compile arch and occupancy
  // calculations)
  // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  prop_ = at::cuda::getCurrentDeviceProperties();
  int major = 0, minor = 0;
  bool compile_to_sass = false;
  codegenOutputQuery(prop_, major, minor, compile_to_sass);

  // Creates the NVRTC program
  nvrtcProgram program{};
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code_.c_str(), nullptr, 0, nullptr, nullptr));

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
  args.push_back("-hip-pch");
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX
      // (compute_), which gives better backwards compatibility on older
      // drivers (since an older driver doesn't necessarily recognize PTX
      // emitted by a newer toolkit).
      // Meanwhile, for forward compatibility (a future device with
      // `compile_to_sass == false`), since SASS is not necessarily compatible,
      // we fall back to PTX instead.
      (compile_to_sass ? "sm_" : "compute_") +
#else
      "compute_" +
#endif
      std::to_string(major) + std::to_string(minor);
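  // e.g. "--gpu-architecture=sm_86" on an 8.6 device when compile_to_sass is
  // true, or "--gpu-architecture=compute_86" when falling back to PTX.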
  const std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif
  const auto result =
      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
  if (result != NVRTC_SUCCESS) {
    // Compilation failed: surface the NVRTC log as the exception message
    size_t logsize = 0;
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
    std::vector<char> log(logsize);
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
    std::stringstream cu;
    cu << log.data();
    // Destroy the program before throwing so it is not leaked on failure
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program));
    throw std::runtime_error(cu.str());
  }
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
  AT_CUDA_NVRTC_CHECK(result);
  size_t ptx_size = 0;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  const auto getSize = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBIN
      : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
  const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx_.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx_.data()));

  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));

  // Computes max blocks
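  // cuOccupancyMaxActiveBlocksPerMultiprocessor reports how many blocks of
  // this kernel (at 128 threads per block and no dynamic shared memory) can be
  // resident on a single SM; multiplying by multiProcessorCount gives a
  // device-wide cap that launch_raw() uses to bound the grid size.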
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &maxBlocks_, function_, 128, 0));
  maxBlocks_ *= prop_->multiProcessorCount;

  // Resets device (end of hacked at::DeviceGuard)
  at::cuda::set_device(prior_device);
}

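// Integer ceiling division, e.g. ceilDiv(1000, 128) == 8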
static int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;
}

void FusedKernelCUDA::launch_raw(
    const uint32_t numel,
    std::vector<void*>& arguments) const {
  at::cuda::CUDAGuard guard{device_};
  // Hacked at::DeviceGuard (see note above)
  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  const auto nBlocks = std::min(maxBlocks_, ceilDiv(numel, kBlockSize));

  // Adds random state to arguments if necessary
  // Note: philox_engine_inputs defined here so its lifetime extends to the
  // launch
  std::pair<uint64_t, uint64_t> philox_engine_inputs;
  if (has_random_) {
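    // Philox yields four random values per call, so each thread makes roughly
    // ceil(numel / (4 * kBlockSize * nBlocks)) calls and advances its counter
    // by 4 per call; the +1 appears to add a safety margin to the reserved
    // offset.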
    const auto rand_offset =
        4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              rand_offset);
    }
    arguments.push_back(&philox_engine_inputs.first);
    arguments.push_back(&philox_engine_inputs.second);
  }

  // Launches kernel on current stream (device was set by executor)
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      nBlocks, // gridDim.x
      1, // gridDim.y
      1, // gridDim.z
      kBlockSize, // blockDim.x
      1, // blockDim.y
      1, // blockDim.z
      0, // dynamic shared memory (bytes)
      stream,
      arguments.data(),
      nullptr));

  // Resets device (see at::DeviceGuard notes above)
  at::cuda::set_device(prior_device);
}

FusedKernelCUDA::~FusedKernelCUDA() {
  nvrtc().cuModuleUnload(module_);
}

static std::shared_ptr<FusedKernel> createFusionKernel(
    int16_t device,
    std::string name,
    std::string code,
    std::vector<TensorDesc> input_desc,
    std::vector<TensorDesc> output_desc,
    std::vector<PartitionDesc> chunk_desc,
    std::vector<PartitionDesc> concat_desc,
    bool has_random) {
  return std::make_shared<FusedKernelCUDA>(
      static_cast<at::DeviceIndex>(device),
      std::move(name),
      std::move(code),
      std::move(input_desc),
      std::move(output_desc),
      std::move(chunk_desc),
      std::move(concat_desc),
      has_random);
}

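// Registers this backend with the fuser so that CUDA fusion groups are
// compiled through createFusionKernel.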
RegisterFusionBackend reg(DeviceType::CUDA, createFusionKernel);

} // namespace torch::jit::fuser::cuda