xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/reducer_cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/reducer_timer.hpp>
2 
3 #include <ATen/cuda/CUDAEvent.h>
4 #include <c10/core/DeviceGuard.h>
5 
6 namespace c10d {
7 namespace {
8 
// Number of nanoseconds in one millisecond; multiplies the millisecond value
// reported by CUDA event timing to express it in nanoseconds.
constexpr int kMilliSecondToNanosSecond = 1'000'000;
10 
11 class CudaTimer : public Timer {
12  private:
13   c10::Device device;
14 
15   at::cuda::CUDAEvent forward_start = at::cuda::CUDAEvent(cudaEventDefault);
16   at::cuda::CUDAEvent backward_compute_start =
17       at::cuda::CUDAEvent(cudaEventDefault);
18   at::cuda::CUDAEvent backward_compute_end =
19       at::cuda::CUDAEvent(cudaEventDefault);
20   at::cuda::CUDAEvent backward_comm_start =
21       at::cuda::CUDAEvent(cudaEventDefault);
22   at::cuda::CUDAEvent backward_comm_end = at::cuda::CUDAEvent(cudaEventDefault);
23 
getEvent(Event event)24   at::cuda::CUDAEvent& getEvent(Event event) {
25     switch (event) {
26       case Event::kForwardStart:
27         return forward_start;
28       case Event::kBackwardComputeStart:
29         return backward_compute_start;
30       case Event::kBackwardComputeEnd:
31         return backward_compute_end;
32       case Event::kBackwardCommStart:
33         return backward_comm_start;
34       case Event::kBackwardCommEnd:
35         return backward_comm_end;
36       default:
37         TORCH_INTERNAL_ASSERT(false);
38     }
39   }
40 
41  public:
CudaTimer(c10::Device dev)42   explicit CudaTimer(c10::Device dev) : device(dev) {}
43 
record(Event event)44   void record(Event event) override {
45     // Parent class sets the host-side time
46     Timer::record(event);
47     c10::DeviceGuard g(device);
48     getEvent(event).record();
49   }
50 
measureDifference(Event start,Event end)51   std::optional<int64_t> measureDifference(Event start, Event end) override {
52     c10::DeviceGuard g(device);
53     at::cuda::CUDAEvent& start_event = getEvent(start);
54     at::cuda::CUDAEvent& end_event = getEvent(end);
55     // It is possible users did not call backward or run codes in
56     // no-sync mode, in this case, some cudaEvents like "backward_compute_end"
57     // or "backward_comm_start" or "backward_comm_end" will not be recorded.
58     // cudaEvent is created when it is first time to be recorded.
59     // If it is never recorded/created, skip synchronize and calculation.
60     // Otherwise it will throw cuda errors.
61     if (!start_event.isCreated() || !end_event.isCreated()) {
62       return std::nullopt;
63     }
64     // set_runtime_stats_and_log is called at the beginning of forward call,
65     // when it is cheap to synchronize the cuda events of previous iteration,
66     // as mostly all cuda operations are finished in previous iteration.
67     start_event.synchronize();
68     end_event.synchronize();
69     float milliseconds = start_event.elapsed_time(end_event);
70     // If gpu_end is not recorded in this iteration,
71     // milliseconds will have invalid value.
72     // For some cases like DDP runs on non-sync mode,
73     // gpu_end can not be recorded in this iteration and thus can not
74     // calculate the valid avg_time.
75     // In this case, skip calculating the avg_time and return.
76     if (milliseconds < 0) {
77       return std::nullopt;
78     }
79     return int64_t(milliseconds * kMilliSecondToNanosSecond);
80   }
81 };
82 
// Registers CudaTimer in TimerRegistry under the c10::kCUDA key.
// NOTE(review): presumably this is how a Timer is obtained for CUDA devices
// elsewhere in c10d -- confirm against the users of TimerRegistry.
83 C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCUDA, CudaTimer);
84 
85 } // namespace
86 } // namespace c10d
87