// xref: /aosp_15_r20/external/pytorch/test/cpp/monitor/test_counters.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <chrono>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torch/csrc/monitor/counters.h>
#include <torch/csrc/monitor/events.h>

using namespace torch::monitor;

// MEAN and COUNT over a double-valued stat whose window holds two samples.
TEST(MonitorTest, CounterDouble) {
  Stat<double> stat{
      "a",
      {Aggregation::MEAN, Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  // First sample: one sample pending in the current window.
  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);

  // Second sample fills the window, so the pending count drops back to 0.
  stat.add(6.0);
  ASSERT_EQ(stat.count(), 0);

  const auto actual = stat.get();
  const std::unordered_map<Aggregation, double, AggregationHash> expected{
      {Aggregation::MEAN, 5.5},
      {Aggregation::COUNT, 2.0},
  };
  ASSERT_EQ(actual, expected);
}

// SUM over a filled two-sample window: 5 + 6 == 11.
TEST(MonitorTest, CounterInt64Sum) {
  Stat<int64_t> stat{
      "a",
      {Aggregation::SUM},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };
  stat.add(5);
  stat.add(6);

  const auto actual = stat.get();
  const std::unordered_map<Aggregation, int64_t, AggregationHash> expected{
      {Aggregation::SUM, 11},
  };
  ASSERT_EQ(actual, expected);
}

// VALUE reports the most recent sample of the window (6, not 5).
TEST(MonitorTest, CounterInt64Value) {
  Stat<int64_t> stat{
      "a",
      {Aggregation::VALUE},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };
  stat.add(5);
  stat.add(6);

  const auto actual = stat.get();
  const std::unordered_map<Aggregation, int64_t, AggregationHash> expected{
      {Aggregation::VALUE, 6},
  };
  ASSERT_EQ(actual, expected);
}

// MEAN for an int64 stat: 0 with no samples, then (0 + 10) / 2 == 5.
TEST(MonitorTest, CounterInt64Mean) {
  using StatMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MEAN},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  // Zero-samples case: MEAN reads as 0 before anything is added.
  {
    const auto actual = stat.get();
    const StatMap expected{
        {Aggregation::MEAN, 0},
    };
    ASSERT_EQ(actual, expected);
  }

  stat.add(0);
  stat.add(10);

  // Two samples fill the window; their mean is reported.
  {
    const auto actual = stat.get();
    const StatMap expected{
        {Aggregation::MEAN, 5},
    };
    ASSERT_EQ(actual, expected);
  }
}

// count() tracks samples pending in the current window (0 -> 1 -> 0 once the
// two-sample window fills), while the COUNT aggregation reports 2.
TEST(MonitorTest, CounterInt64Count) {
  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  ASSERT_EQ(stat.count(), 0);

  stat.add(0);
  ASSERT_EQ(stat.count(), 1);

  stat.add(10);
  ASSERT_EQ(stat.count(), 0);

  const auto actual = stat.get();
  const std::unordered_map<Aggregation, int64_t, AggregationHash> expected{
      {Aggregation::COUNT, 2},
  };
  ASSERT_EQ(actual, expected);
}

// MIN/MAX: both read 0 with no samples, then track the extremes of a
// six-sample window containing positive and negative values.
TEST(MonitorTest, CounterInt64MinMax) {
  using StatMap = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MIN, Aggregation::MAX},
      std::chrono::milliseconds(100000),
      /*maxSamples=*/6,
  };

  // Empty window.
  {
    const auto actual = stat.get();
    const StatMap expected{
        {Aggregation::MAX, 0},
        {Aggregation::MIN, 0},
    };
    ASSERT_EQ(actual, expected);
  }

  // Exactly six samples, filling the window.
  stat.add(0);
  stat.add(5);
  stat.add(-5);
  stat.add(-6);
  stat.add(9);
  stat.add(2);

  {
    const auto actual = stat.get();
    const StatMap expected{
        {Aggregation::MAX, 9},
        {Aggregation::MIN, -6},
    };
    ASSERT_EQ(actual, expected);
  }
}

// Checks the per-window sample cap: with maxSamples = 3, the first three
// samples are aggregated and later samples in the same window are not.
TEST(MonitorTest, CounterInt64WindowSize) {
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // The window already reached maxSamples and was logged. A further sample in
  // the same window is not counted: count() stays 0, and the SUM below is
  // 1 + 2 + 3 (it excludes 4).
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}

// Same sample-cap check as CounterInt64WindowSize, but with a very long
// window duration (10 years).
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      /*windowSize=*/std::chrono::hours(24 * 365 * 10), // 10 years
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // The window already reached maxSamples and was logged. A further sample in
  // the same window is not counted: count() stays 0, and the SUM below is
  // 1 + 2 + 3 (it excludes 4).
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}

192 template <typename T>
193 struct TestStat : public Stat<T> {
194   uint64_t mockWindowId{1};
195 
TestStatTestStat196   TestStat(
197       std::string name,
198       std::initializer_list<Aggregation> aggregations,
199       std::chrono::milliseconds windowSize,
200       int64_t maxSamples = std::numeric_limits<int64_t>::max())
201       : Stat<T>(name, aggregations, windowSize, maxSamples) {}
202 
currentWindowIdTestStat203   uint64_t currentWindowId() const override {
204     return mockWindowId;
205   }
206 };
207 
208 struct AggregatingEventHandler : public EventHandler {
209   std::vector<Event> events;
210 
handleAggregatingEventHandler211   void handle(const Event& e) override {
212     events.emplace_back(e);
213   }
214 };
215 
/// Scope guard for an event handler of type T.
///
/// Construction creates a new T and passes it to registerEventHandler().
/// Destruction passes the same handler to unregisterEventHandler().
/// Copying is disabled: a copy would share `handler`, and each destructor
/// would unregister it, so unregistration would happen twice.
template <typename T>
struct HandlerGuard {
  std::shared_ptr<T> handler;

  HandlerGuard() : handler(std::make_shared<T>()) {
    registerEventHandler(handler);
  }

  HandlerGuard(const HandlerGuard&) = delete;
  HandlerGuard& operator=(const HandlerGuard&) = delete;

  ~HandlerGuard() {
    unregisterEventHandler(handler);
  }
};

// Real-clock window rollover. The window is 1ms; after sleeping 2ms, at least
// one event must have been emitted. The bounds are loose because wall-clock
// timing may also roll the window over at other points.
TEST(MonitorTest, Stat) {
  HandlerGuard<AggregatingEventHandler> guard;
  const auto& events = guard.handler->events;

  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  ASSERT_EQ(events.size(), 0);

  stat.add(1);
  ASSERT_LE(stat.count(), 1);

  std::this_thread::sleep_for(std::chrono::milliseconds(2));
  stat.add(2);
  ASSERT_LE(stat.count(), 1);

  ASSERT_GE(events.size(), 1);
  ASSERT_LE(events.size(), 2);
}

// Deterministic window rollover via TestStat::mockWindowId. Changing the
// window id and adding a sample must emit exactly one "torch.monitor.Stat"
// event carrying the previous window's aggregates.
TEST(MonitorTest, StatEvent) {
  HandlerGuard<AggregatingEventHandler> guard;

  TestStat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  ASSERT_EQ(guard.handler->events.size(), 0);

  // Same mocked window id: samples accumulate and no event is emitted.
  a.add(1);
  ASSERT_EQ(a.count(), 1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  ASSERT_EQ(guard.handler->events.size(), 0);

  a.mockWindowId = 100;

  a.add(3);
  ASSERT_LE(a.count(), 1);

  ASSERT_EQ(guard.handler->events.size(), 1);
  // `events` is not modified after this point, so a reference stays valid.
  const Event& e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Use int64_t explicitly. A `long` literal (3L) is not int64_t on every
  // platform and would need a conversion to pick the int64_t alternative.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{3}},
      {"a.count", int64_t{2}},
  };
  ASSERT_EQ(e.data, data);
}

// Destroying a stat that still holds samples in an unfinished (10-hour)
// window must emit one "torch.monitor.Stat" event containing those samples.
TEST(MonitorTest, StatEventDestruction) {
  HandlerGuard<AggregatingEventHandler> guard;

  {
    TestStat<int64_t> a{
        "a",
        {Aggregation::COUNT, Aggregation::SUM},
        std::chrono::hours(10),
    };
    a.add(1);
    ASSERT_EQ(a.count(), 1);
    ASSERT_EQ(guard.handler->events.size(), 0);
  }
  ASSERT_EQ(guard.handler->events.size(), 1);

  // `events` is not modified after this point, so a reference stays valid.
  const Event& e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Use int64_t explicitly. A `long` literal (1L) is not int64_t on every
  // platform and would need a conversion to pick the int64_t alternative.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{1}},
      {"a.count", int64_t{1}},
  };
  ASSERT_EQ(e.data, data);
}
