#include <gtest/gtest.h>

#include <chrono>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torch/csrc/monitor/counters.h>
#include <torch/csrc/monitor/events.h>

using namespace torch::monitor;
9
// MEAN + COUNT over double samples with a sample limit of 2: count() is 1
// after the first sample and back to 0 after the second, and get() then
// reports a mean of 5.5 over 2 samples.
TEST(MonitorTest, CounterDouble) {
  using Snapshot = std::unordered_map<Aggregation, double, AggregationHash>;

  Stat<double> stat{
      "a",
      {Aggregation::MEAN, Aggregation::COUNT},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  stat.add(5.0);
  ASSERT_EQ(stat.count(), 1);

  stat.add(6.0);
  ASSERT_EQ(stat.count(), 0);

  const Snapshot expected = {
      {Aggregation::MEAN, 5.5},
      {Aggregation::COUNT, 2.0},
  };
  ASSERT_EQ(stat.get(), expected);
}
29
// SUM aggregation: the samples 5 and 6 are expected to be reported as 11.
TEST(MonitorTest, CounterInt64Sum) {
  Stat<int64_t> stat{
      "a",
      {Aggregation::SUM},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  // Samples are added in order: 5, then 6.
  for (const int64_t sample : {5, 6}) {
    stat.add(sample);
  }

  const std::unordered_map<Aggregation, int64_t, AggregationHash> expected = {
      {Aggregation::SUM, 11},
  };
  ASSERT_EQ(stat.get(), expected);
}
45
// VALUE aggregation: after adding 5 and then 6, get() is expected to report
// the last sample added (6).
TEST(MonitorTest, CounterInt64Value) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::VALUE},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };
  stat.add(5);
  stat.add(6);

  const Snapshot expected = {
      {Aggregation::VALUE, 6},
  };
  const Snapshot actual = stat.get();
  ASSERT_EQ(actual, expected);
}
61
// MEAN aggregation over int64 samples: get() reports 0 before any sample is
// added, and 5 once the samples 0 and 10 have been recorded.
TEST(MonitorTest, CounterInt64Mean) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MEAN},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  // Zero-samples case: the mean is reported as 0.
  const Snapshot emptyExpected = {
      {Aggregation::MEAN, 0},
  };
  ASSERT_EQ(stat.get(), emptyExpected);

  stat.add(0);
  stat.add(10);

  const Snapshot filledExpected = {
      {Aggregation::MEAN, 5},
  };
  ASSERT_EQ(stat.get(), filledExpected);
}
89
// COUNT aggregation with a sample limit of 2: count() goes 0 -> 1 -> 0 as the
// two samples are added, and get() reports a count of 2.
TEST(MonitorTest, CounterInt64Count) {
  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/2,
  };

  ASSERT_EQ(stat.count(), 0);
  stat.add(0);
  ASSERT_EQ(stat.count(), 1);
  stat.add(10);
  ASSERT_EQ(stat.count(), 0);

  const std::unordered_map<Aggregation, int64_t, AggregationHash> expected = {
      {Aggregation::COUNT, 2},
  };
  ASSERT_EQ(stat.get(), expected);
}
109
// MIN + MAX aggregation: both report 0 before any sample, then the extremes
// (-6 and 9) of six mixed-sign samples.
TEST(MonitorTest, CounterInt64MinMax) {
  using Snapshot = std::unordered_map<Aggregation, int64_t, AggregationHash>;

  Stat<int64_t> stat{
      "a",
      {Aggregation::MIN, Aggregation::MAX},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/6,
  };

  // Zero-samples case.
  const Snapshot emptyExpected = {
      {Aggregation::MAX, 0},
      {Aggregation::MIN, 0},
  };
  ASSERT_EQ(stat.get(), emptyExpected);

  for (const int64_t sample : {0, 5, -5, -6, 9, 2}) {
    stat.add(sample);
  }

  const Snapshot filledExpected = {
      {Aggregation::MAX, 9},
      {Aggregation::MIN, -6},
  };
  ASSERT_EQ(stat.get(), filledExpected);
}
141
// Exercises the per-window sample limit (Stat's 4th argument, maxSamples = 3)
// within a single 100s window: count() resets to 0 once the third sample is
// added, and a further sample in the same window is not counted, so get()
// reports only the first three samples.
TEST(MonitorTest, CounterInt64WindowSize) {
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      /*windowSize=*/std::chrono::milliseconds(100000),
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // maxSamples has been reached for this window: further samples in the same
  // window are dropped, so count() stays 0 and the 4 is excluded from get().
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}
166
// Same sample-limit scenario as CounterInt64WindowSize, but with a 10-year
// window duration (presumably to guard against overflow in the window
// computation — TODO confirm intent): maxSamples = 3, and the 4th sample in
// the same window is not counted.
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
  Stat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      /*windowSize=*/std::chrono::hours(24 * 365 * 10), // 10 years
      /*maxSamples=*/3,
  };
  a.add(1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  a.add(3);
  ASSERT_EQ(a.count(), 0);

  // maxSamples has been reached for this window: further samples in the same
  // window are dropped, so count() stays 0 and the 4 is excluded from get().
  a.add(4);
  ASSERT_EQ(a.count(), 0);

  auto stats = a.get();
  std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
      {Aggregation::COUNT, 3},
      {Aggregation::SUM, 6},
  };
  ASSERT_EQ(stats, want);
}
191
192 template <typename T>
193 struct TestStat : public Stat<T> {
194 uint64_t mockWindowId{1};
195
TestStatTestStat196 TestStat(
197 std::string name,
198 std::initializer_list<Aggregation> aggregations,
199 std::chrono::milliseconds windowSize,
200 int64_t maxSamples = std::numeric_limits<int64_t>::max())
201 : Stat<T>(name, aggregations, windowSize, maxSamples) {}
202
currentWindowIdTestStat203 uint64_t currentWindowId() const override {
204 return mockWindowId;
205 }
206 };
207
208 struct AggregatingEventHandler : public EventHandler {
209 std::vector<Event> events;
210
handleAggregatingEventHandler211 void handle(const Event& e) override {
212 events.emplace_back(e);
213 }
214 };
215
// Scope guard for event handlers: constructs a T, registers it with
// registerEventHandler() on construction and unregisters it in the destructor.
//
// Copy and move are deleted. A copy would share `handler` and unregister it
// twice. A move would leave the moved-from guard unregistering a null pointer.
template <typename T>
struct HandlerGuard {
  std::shared_ptr<T> handler;

  HandlerGuard() : handler(std::make_shared<T>()) {
    registerEventHandler(handler);
  }

  HandlerGuard(const HandlerGuard&) = delete;
  HandlerGuard& operator=(const HandlerGuard&) = delete;
  HandlerGuard(HandlerGuard&&) = delete;
  HandlerGuard& operator=(HandlerGuard&&) = delete;

  ~HandlerGuard() {
    unregisterEventHandler(handler);
  }
};
228
// Real-clock variant: with a 1ms window, sleeping 2ms between two samples is
// expected to make the registered handler see between one and two events.
TEST(MonitorTest, Stat) {
  HandlerGuard<AggregatingEventHandler> guard;
  const auto& events = guard.handler->events;

  Stat<int64_t> stat{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      /*windowSize=*/std::chrono::milliseconds(1),
  };
  ASSERT_EQ(events.size(), 0);

  stat.add(1);
  ASSERT_LE(stat.count(), 1);

  // Sleep past the 1ms window so the next sample presumably lands in a later
  // window.
  std::this_thread::sleep_for(std::chrono::milliseconds(2));
  stat.add(2);
  ASSERT_LE(stat.count(), 1);

  ASSERT_GE(events.size(), 1);
  ASSERT_LE(events.size(), 2);
}
249
// Drives window transitions through TestStat::mockWindowId.
// - No event is emitted while the window id is unchanged.
// - After bumping it, the next add() emits exactly one "torch.monitor.Stat"
//   event carrying the previous window's SUM (3) and COUNT (2).
TEST(MonitorTest, StatEvent) {
  HandlerGuard<AggregatingEventHandler> guard;

  TestStat<int64_t> a{
      "a",
      {Aggregation::COUNT, Aggregation::SUM},
      std::chrono::milliseconds(1),
  };
  ASSERT_EQ(guard.handler->events.size(), 0);

  a.add(1);
  ASSERT_EQ(a.count(), 1);
  a.add(2);
  ASSERT_EQ(a.count(), 2);
  ASSERT_EQ(guard.handler->events.size(), 0);

  // Simulate moving into a new window.
  a.mockWindowId = 100;

  a.add(3);
  ASSERT_LE(a.count(), 1);

  ASSERT_EQ(guard.handler->events.size(), 1);
  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Spell the expected values as int64_t (the Stat's value type) rather than
  // `3L`. `long` is 32-bit and distinct from int64_t on LLP64 platforms such
  // as Windows. There, the data_value_t alternative it converts to may differ,
  // or the conversion may be ambiguous.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{3}},
      {"a.count", int64_t{2}},
  };
  ASSERT_EQ(e.data, data);
}
281
// Destroying a Stat that still holds samples in its (10h, not yet elapsed)
// window is expected to emit exactly one "torch.monitor.Stat" event with the
// pending SUM and COUNT.
TEST(MonitorTest, StatEventDestruction) {
  HandlerGuard<AggregatingEventHandler> guard;

  {
    TestStat<int64_t> a{
        "a",
        {Aggregation::COUNT, Aggregation::SUM},
        std::chrono::hours(10),
    };
    a.add(1);
    ASSERT_EQ(a.count(), 1);
    ASSERT_EQ(guard.handler->events.size(), 0);
  }
  ASSERT_EQ(guard.handler->events.size(), 1);

  Event e = guard.handler->events.at(0);
  ASSERT_EQ(e.name, "torch.monitor.Stat");
  ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
  // Spell the expected values as int64_t (the Stat's value type) rather than
  // `1L`. `long` is not int64_t on LLP64 platforms such as Windows.
  std::unordered_map<std::string, data_value_t> data{
      {"a.sum", int64_t{1}},
      {"a.count", int64_t{1}},
  };
  ASSERT_EQ(e.data, data);
}
306