xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/script_profile.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/script_profile.h>
2 
3 #include <atomic>
4 #include <chrono>
5 #include <mutex>
6 #include <unordered_set>
7 
8 #include <c10/util/Exception.h>
9 #include <c10/util/intrusive_ptr.h>
10 #include <torch/csrc/jit/api/function_impl.h>
11 
12 namespace torch::jit {
13 
14 namespace {
15 
16 class ProfilesRegistry {
17  public:
empty()18   bool empty() {
19     return empty_.load(std::memory_order_relaxed);
20   }
21 
addProfile(ScriptProfile & p)22   void addProfile(ScriptProfile& p) {
23     std::lock_guard<std::mutex> g(mutex_);
24     enabledProfiles_.emplace(&p);
25     empty_.store(false, std::memory_order_relaxed);
26   }
27 
removeProfile(ScriptProfile & p)28   void removeProfile(ScriptProfile& p) {
29     std::lock_guard<std::mutex> g(mutex_);
30     enabledProfiles_.erase(&p);
31     if (enabledProfiles_.empty()) {
32       empty_.store(true, std::memory_order_relaxed);
33     }
34   }
35 
send(std::unique_ptr<profiling::Datapoint> datapoint)36   void send(std::unique_ptr<profiling::Datapoint> datapoint) {
37     auto shared = std::shared_ptr<profiling::Datapoint>(std::move(datapoint));
38     std::lock_guard<std::mutex> g(mutex_);
39     for (auto* p : enabledProfiles_) {
40       p->addDatapoint(shared);
41     }
42   }
43 
44  private:
45   std::atomic<bool> empty_{true};
46   std::mutex mutex_;
47   std::unordered_set<ScriptProfile*> enabledProfiles_;
48 };
49 
getProfilesRegistry()50 ProfilesRegistry& getProfilesRegistry() {
51   static auto registry = std::ref(*new ProfilesRegistry{});
52   return registry;
53 }
54 
initBindings()55 auto initBindings() {
56   torch::class_<SourceRef>("profiling", "SourceRef")
57       .def(
58           "starting_lineno",
59           [](const c10::intrusive_ptr<SourceRef>& self) {
60             return static_cast<int64_t>((*self)->starting_line_no());
61           })
62       .def("text", [](const c10::intrusive_ptr<SourceRef>& self) {
63         return (*self)->text_str().str();
64       });
65 
66   torch::class_<InstructionStats>("profiling", "InstructionStats")
67       .def(
68           "count",
69           [](const c10::intrusive_ptr<InstructionStats>& self) {
70             return self->count;
71           })
72       .def("duration_ns", [](const c10::intrusive_ptr<InstructionStats>& self) {
73         return static_cast<int64_t>(self->duration.count());
74       });
75 
76   torch::class_<SourceStats>("profiling", "SourceStats")
77       .def(
78           "source",
79           [](const c10::intrusive_ptr<SourceStats>& self) {
80             return c10::make_intrusive<SourceRef>(self->getSourceRef());
81           })
82       .def("line_map", &SourceStats::getLineMap);
83 
84   torch::class_<ScriptProfile>("profiling", "_ScriptProfile")
85       .def(torch::init<>())
86       .def("enable", &ScriptProfile::enable)
87       .def("disable", &ScriptProfile::disable)
88       .def("_dump_stats", [](const c10::intrusive_ptr<ScriptProfile>& self) {
89         const auto& stats = self->dumpStats();
90         c10::List<c10::intrusive_ptr<SourceStats>> ret;
91         for (const auto& source : stats) {
92           SourceStats::LineMap lineMap;
93           for (const auto& line : source.second) {
94             lineMap.insert(
95                 line.first, c10::make_intrusive<InstructionStats>(line.second));
96           }
97           ret.push_back(c10::make_intrusive<SourceStats>(
98               source.first, std::move(lineMap)));
99         }
100         return ret;
101       });
102   return nullptr;
103 }
104 
// Initializing this namespace-scope constant calls initBindings(), which
// performs the class registrations. The stored value (nullptr) is never read,
// hence the C10_UNUSED annotation.
const auto C10_UNUSED torchBindInitializer = initBindings();
106 
107 } // namespace
108 
109 namespace profiling {
110 
InstructionSpan(Node & node)111 InstructionSpan::InstructionSpan(Node& node) {
112   datapoint_ = std::make_unique<Datapoint>(node.sourceRange());
113 }
114 
~InstructionSpan()115 InstructionSpan::~InstructionSpan() {
116   datapoint_->end = std::chrono::steady_clock::now();
117   getProfilesRegistry().send(std::move(datapoint_));
118 }
119 
isProfilingOngoing()120 bool isProfilingOngoing() {
121   return !getProfilesRegistry().empty();
122 }
123 
124 } // namespace profiling
125 
enable()126 void ScriptProfile::enable() {
127   if (!std::exchange(enabled_, true)) {
128     getProfilesRegistry().addProfile(*this);
129   }
130 }
131 
disable()132 void ScriptProfile::disable() {
133   if (std::exchange(enabled_, false)) {
134     getProfilesRegistry().removeProfile(*this);
135   }
136 }
137 
addDatapoint(std::shared_ptr<profiling::Datapoint> datapoint)138 void ScriptProfile::addDatapoint(
139     std::shared_ptr<profiling::Datapoint> datapoint) {
140   TORCH_CHECK(enabled_, "Cannot only add datapoint to disabled profilers.");
141   datapoints_.push_back(std::move(datapoint));
142 }
143 
dumpStats()144 const ScriptProfile::SourceMap& ScriptProfile::dumpStats() {
145   TORCH_CHECK(!enabled_, "Only disabled profilers are allowed to dump stats.");
146 
147   for (const auto& datapoint : datapoints_) {
148     if (const auto& source = datapoint->sourceRange.source()) {
149       if (auto fileLineCol = datapoint->sourceRange.file_line_col()) {
150         auto it = sourceMap_.find(*source);
151         if (it == sourceMap_.end()) {
152           it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first;
153         }
154         auto& stats = it->second[std::get<1>(*fileLineCol)];
155         stats.count++;
156         stats.duration += datapoint->end - datapoint->start;
157       }
158     }
159   }
160   datapoints_.clear();
161 
162   return sourceMap_;
163 }
164 
~ScriptProfile()165 ScriptProfile::~ScriptProfile() {
166   if (enabled_) {
167     getProfilesRegistry().removeProfile(*this);
168   }
169 }
170 
171 } // namespace torch::jit
172