#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>

#include <torch/csrc/jit/codegen/fuser/compiler.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/jit/resource_guard.h>

#include <cuda_runtime.h>

#include <algorithm>
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <vector>

namespace torch::jit::fuser::cuda {

// See NOTE [ USE OF NVRTC AND DRIVER API ]
const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

// query codegen output arch and target
void codegenOutputQuery(
    const cudaDeviceProp* const prop,
    int& major,
    int& minor,
    bool& compile_to_sass) {
#ifdef USE_ROCM
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&major, &minor));
  compile_to_sass = false;
#else
  using CudaVersion = std::pair<int, int>;
  CudaVersion nvrtc_version;
  AT_CUDA_NVRTC_CHECK(
      nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));

  TORCH_CHECK(
      nvrtc_version.first >= 6,
      "NVRTC versions less than 6 are not supported. Is: ",
      nvrtc_version.first);

  // Version supported by device
  // Usually any lower version works too but is less efficient
  const CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
  // Maximum compute capability supported by this NVRTC version;
  // cap dev_version to this
  CudaVersion max_dev_version;
  if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
    max_dev_version = CudaVersion(5, 0);
  } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
    max_dev_version = CudaVersion(6, 0);
  } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
    max_dev_version = CudaVersion(7, 2);
  } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
    max_dev_version = CudaVersion(7, 5);
  } else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0
    max_dev_version = CudaVersion(8, 0);
  } else if (nvrtc_version.first == 11 && nvrtc_version.second < 8) {
    max_dev_version = CudaVersion(8, 6);
  } else {
    // If the NVRTC version is unknown (i.e. newer than this code)
    // assume it supports this device
    max_dev_version = dev_version;
  }
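  // Example: NVRTC 11.2 lands in the `second < 8` branch above, capping
  // max_dev_version at (8, 6); an sm_90 device is then clamped to 8.6
  // below and compiled to PTX rather than SASS.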
  if (dev_version > max_dev_version) {
    major = max_dev_version.first;
    minor = max_dev_version.second;
    // if we are clamping major/minor, sass is not compatible
    compile_to_sass = false;
  } else {
    major = dev_version.first;
    minor = dev_version.second;
    compile_to_sass = true;
  }
#endif
}

// Compiles the specified kernel and stores the metadata required to run it
FusedKernelCUDA::FusedKernelCUDA(
    at::DeviceIndex device,
    std::string name,
    std::string code,
    std::vector<TensorDesc> input_desc,
    std::vector<TensorDesc> output_desc,
    std::vector<PartitionDesc> chunk_desc,
    std::vector<PartitionDesc> concat_desc,
    bool has_random)
    : FusedKernel(
          std::move(name),
          std::move(code),
          std::move(input_desc),
          std::move(output_desc),
          std::move(chunk_desc),
          std::move(concat_desc),
          has_random),
      device_(device) {
  // Initializes driver's API context (if necessary)
  at::cuda::jit::initializeCudaContext();

  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios
  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  // Acquires device and NVRTC properties (for compile arch and occupancy
  // calculations)
  // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  prop_ = at::cuda::getCurrentDeviceProperties();
  int major = 0, minor = 0;
  bool compile_to_sass = false;
  codegenOutputQuery(prop_, major, minor, compile_to_sass);

  // Creates the NVRTC program from the generated source; no program name
  // or extra headers are supplied
  nvrtcProgram program{};
  AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
      &program, code_.c_str(), nullptr, 0, nullptr, nullptr));

#if defined(USE_ROCM)
  std::vector<const char*> args = {"--std=c++17"};
  args.push_back("-hip-pch");
#else
  const std::string compute = std::string("--gpu-architecture=") +
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX
      // (compute_), which gives better backwards compatibility on older
      // drivers (an older driver doesn't necessarily recognize PTX emitted
      // by a newer toolkit).
      // Meanwhile, for forward compatibility (a future device with
      // `compile_to_sass == false`), since SASS is not necessarily
      // compatible across architectures, we fall back to PTX instead.
      (compile_to_sass ? "sm_" : "compute_") +
#else
      "compute_" +
#endif
      std::to_string(major) + std::to_string(minor);
  const std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif
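  // On the CUDA path, `compute` now reads e.g. "--gpu-architecture=sm_86"
  // (direct-to-SASS) or "--gpu-architecture=compute_86" (PTX).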
  const auto result =
      nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
  if (result != NVRTC_SUCCESS) {
    size_t logsize = 0;
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
    std::vector<char> log(logsize);
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
    std::stringstream cu;
    cu << log.data();
    throw std::runtime_error(cu.str());
  }
  ResourceGuard holdProgram(
      [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
  AT_CUDA_NVRTC_CHECK(result);
  size_t ptx_size = 0;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
  // compile_to_sass determines whether we are generating SASS or PTX, hence
  // the different API.
  const auto getSize = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = compile_to_sass
      ? at::globalContext().getNVRTC().nvrtcGetCUBIN
      : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
  const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
  AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
  ptx_.resize(ptx_size);
  AT_CUDA_NVRTC_CHECK(getFunc(program, ptx_.data()));

  AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module_, ptx_.data()));
  AT_CUDA_DRIVER_CHECK(
      nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str()));

  // Computes max blocks for the whole device: the occupancy-limited number
  // of resident blocks per SM (for 128 threads/block and no dynamic shared
  // memory) times the SM count
  AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor(
      &maxBlocks_, function_, 128, 0));
  maxBlocks_ *= prop_->multiProcessorCount;

  // Resets device (end of hacked at::DeviceGuard)
  at::cuda::set_device(prior_device);
}

static int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;
}
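// e.g. ceilDiv(1000, 128) == 8 (integer division, rounded up)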

void FusedKernelCUDA::launch_raw(
    const uint32_t numel,
    std::vector<void*>& arguments) const {
  at::cuda::CUDAGuard guard{device_};
  // Hacked at::DeviceGuard (see note above)
  const auto prior_device = at::cuda::current_device();
  at::cuda::set_device(device_);

  const auto nBlocks = std::min(maxBlocks_, ceilDiv(numel, kBlockSize));
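  // e.g. numel == 1'000'000 with kBlockSize == 128 requests
  // ceilDiv == 7813 blocks, capped by the occupancy-derived maxBlocks_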

  // Adds random state to arguments if necessary
  // Note: philox_engine_inputs defined here so its lifetime extends to the
  // launch
  std::pair<uint64_t, uint64_t> philox_engine_inputs;
  if (has_random_) {
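    // Philox generates 4 random values per call, so the offset below
    // rounds each thread's worst-case consumption up to a whole number of
    // calls and pads by one extra call to stay a safe overestimate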
    const auto rand_offset =
        4 * (std::ceil(numel / (4.0 * kBlockSize * nBlocks)) + 1);
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              rand_offset);
    }
    arguments.push_back(&philox_engine_inputs.first);
    arguments.push_back(&philox_engine_inputs.second);
  }

  // Launches kernel on current stream (device was set by executor)
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      nBlocks, // gridDimX
      1, // gridDimY
      1, // gridDimZ
      kBlockSize, // blockDimX
      1, // blockDimY
      1, // blockDimZ
      0, // sharedMemBytes
      stream,
      arguments.data(), // kernelParams
      nullptr)); // extra

  // Resets device (see at::DeviceGuard notes above)
  at::cuda::set_device(prior_device);
}

FusedKernelCUDA::~FusedKernelCUDA() {
  nvrtc().cuModuleUnload(module_);
}

static std::shared_ptr<FusedKernel> createFusionKernel(
    int16_t device,
    std::string name,
    std::string code,
    std::vector<TensorDesc> input_desc,
    std::vector<TensorDesc> output_desc,
    std::vector<PartitionDesc> chunk_desc,
    std::vector<PartitionDesc> concat_desc,
    bool has_random) {
  return std::make_shared<FusedKernelCUDA>(
      static_cast<at::DeviceIndex>(device),
      std::move(name),
      std::move(code),
      std::move(input_desc),
      std::move(output_desc),
      std::move(chunk_desc),
      std::move(concat_desc),
      has_random);
}

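// Registers this factory at static-initialization time so the fuser can
// dispatch kernel compilation for CUDA devices to createFusionKernel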
RegisterFusionBackend reg(DeviceType::CUDA, createFusionKernel);

} // namespace torch::jit::fuser::cuda