xref: /aosp_15_r20/external/pytorch/test/cpp/c10d/TestUtils.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#ifndef _WIN32
#include <signal.h>
#include <sys/wait.h>
#include <unistd.h>
#endif

#include <sys/types.h>

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <condition_variable>
#include <mutex>
#include <string>
#include <system_error>
#include <vector>
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker namespace c10d {
19*da0073e9SAndroid Build Coastguard Worker namespace test {
20*da0073e9SAndroid Build Coastguard Worker 
/// Counting semaphore built from a mutex and a condition variable.
///
/// post(n) adds n permits and wakes every waiter; wait(n) blocks until at
/// least n permits are available and then consumes exactly n of them.
class Semaphore {
 public:
  void post(int n = 1) {
    std::lock_guard<std::mutex> guard(m_);
    n_ += n;
    // Wake all waiters: each may be waiting for a different permit count.
    cv_.notify_all();
  }

  void wait(int n = 1) {
    std::unique_lock<std::mutex> guard(m_);
    // The predicate form re-checks after every wakeup (spurious or not).
    cv_.wait(guard, [this, n] { return n_ >= n; });
    n_ -= n;
  }

 protected:
  int n_ = 0;                   // permits currently available; guarded by m_
  std::mutex m_;                // protects n_
  std::condition_variable cv_;  // notified whenever n_ is increased
};
42*da0073e9SAndroid Build Coastguard Worker 
43*da0073e9SAndroid Build Coastguard Worker #ifdef _WIN32
autoGenerateTmpFilePath()44*da0073e9SAndroid Build Coastguard Worker std::string autoGenerateTmpFilePath() {
45*da0073e9SAndroid Build Coastguard Worker   char tmp[L_tmpnam_s];
46*da0073e9SAndroid Build Coastguard Worker   errno_t err;
47*da0073e9SAndroid Build Coastguard Worker   err = tmpnam_s(tmp, L_tmpnam_s);
48*da0073e9SAndroid Build Coastguard Worker   if (err != 0)
49*da0073e9SAndroid Build Coastguard Worker   {
50*da0073e9SAndroid Build Coastguard Worker     throw std::system_error(errno, std::system_category());
51*da0073e9SAndroid Build Coastguard Worker   }
52*da0073e9SAndroid Build Coastguard Worker   return std::string(tmp);
53*da0073e9SAndroid Build Coastguard Worker }
54*da0073e9SAndroid Build Coastguard Worker 
tmppath()55*da0073e9SAndroid Build Coastguard Worker std::string tmppath() {
56*da0073e9SAndroid Build Coastguard Worker   const char* tmpfile = getenv("TMPFILE");
57*da0073e9SAndroid Build Coastguard Worker   if (tmpfile) {
58*da0073e9SAndroid Build Coastguard Worker     return std::string(tmpfile);
59*da0073e9SAndroid Build Coastguard Worker   }
60*da0073e9SAndroid Build Coastguard Worker   else {
61*da0073e9SAndroid Build Coastguard Worker     return autoGenerateTmpFilePath();
62*da0073e9SAndroid Build Coastguard Worker   }
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker #else
/// Returns a path to a freshly created, empty temporary file (POSIX).
///
/// If the TMPFILE environment variable is set, its value is returned verbatim
/// and no file is created. Otherwise a unique file is created with mkstemp()
/// inside $TMPDIR, or inside /tmp when TMPDIR is unset or empty. Its
/// descriptor is closed and the path is returned. Removing the file is left
/// to the caller.
///
/// @throws std::system_error if mkstemp() fails.
inline std::string tmppath() {
  // TMPFILE is for manual test execution during which the user will specify
  // the full temp file path using the environmental variable TMPFILE
  if (const char* tmpfile = std::getenv("TMPFILE")) {
    return std::string(tmpfile);
  }

  const char* tmpdir = std::getenv("TMPDIR");
  if (tmpdir == nullptr || *tmpdir == '\0') {
    tmpdir = "/tmp";
  }

  // Build the mkstemp() template in a std::string. It grows to fit a TMPDIR
  // of any length, so the trailing XXXXXX is never truncated. It also keeps a
  // NUL terminator right after its last character, which mkstemp() requires.
  std::string path(tmpdir);
  path += "/testXXXXXX";

  // mkstemp() creates the file and rewrites the XXXXXX suffix in place
  // (same length), so path then holds the real file name.
  const int fd = mkstemp(&path[0]);
  if (fd == -1) {
    throw std::system_error(errno, std::system_category());
  }
  close(fd);
  return path;
}
91*da0073e9SAndroid Build Coastguard Worker #endif
92*da0073e9SAndroid Build Coastguard Worker 
/// Reports whether the PYTORCH_TEST_WITH_TSAN environment variable is set to
/// exactly "1" (presumably set when the tests run under ThreadSanitizer).
/// An unset variable or any other value yields false.
/// Declared inline because this is a header.
inline bool isTSANEnabled() {
  const char* s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  return s != nullptr && std::strcmp(s, "1") == 0;
}
97*da0073e9SAndroid Build Coastguard Worker struct TemporaryFile {
98*da0073e9SAndroid Build Coastguard Worker   std::string path;
99*da0073e9SAndroid Build Coastguard Worker 
TemporaryFilec10d::test::TemporaryFile100*da0073e9SAndroid Build Coastguard Worker   TemporaryFile() {
101*da0073e9SAndroid Build Coastguard Worker     path = tmppath();
102*da0073e9SAndroid Build Coastguard Worker   }
103*da0073e9SAndroid Build Coastguard Worker 
~TemporaryFilec10d::test::TemporaryFile104*da0073e9SAndroid Build Coastguard Worker   ~TemporaryFile() {
105*da0073e9SAndroid Build Coastguard Worker     unlink(path.c_str());
106*da0073e9SAndroid Build Coastguard Worker   }
107*da0073e9SAndroid Build Coastguard Worker };
108*da0073e9SAndroid Build Coastguard Worker 
109*da0073e9SAndroid Build Coastguard Worker #ifndef _WIN32
/// Forks the current process on construction.
///
/// In the parent, `pid` holds the child's process id; in the child it is 0
/// (see isChild()). Destroying the parent's Fork sends SIGKILL to the child
/// and reaps it with waitpid(). Destroying the child's Fork does nothing.
///
/// @throws std::system_error if fork() fails.
struct Fork {
  pid_t pid;

  Fork() {
    pid = fork();
    if (pid < 0) {
      throw std::system_error(errno, std::system_category(), "fork");
    }
  }

  ~Fork() {
    if (pid > 0) {
      kill(pid, SIGKILL);
      // Retry if a signal interrupts the wait; giving up on EINTR would
      // leave the killed child unreaped.
      while (waitpid(pid, nullptr, 0) == -1 && errno == EINTR) {
      }
    }
  }

  bool isChild() const {
    return pid == 0;
  }
};
131*da0073e9SAndroid Build Coastguard Worker #endif
132*da0073e9SAndroid Build Coastguard Worker 
133*da0073e9SAndroid Build Coastguard Worker } // namespace test
134*da0073e9SAndroid Build Coastguard Worker } // namespace c10d
135