1*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/DataLoader.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker // Together with `torch/utils/data/_utils/signal_handling.py`, the following
4*da0073e9SAndroid Build Coastguard Worker // is an effort to do our best to provide some error message to users when a
5*da0073e9SAndroid Build Coastguard Worker // worker dies due to error / critical signals.
6*da0073e9SAndroid Build Coastguard Worker //
7*da0073e9SAndroid Build Coastguard Worker // See NOTE [ Signal handling in multiprocessing data loading ] for more
8*da0073e9SAndroid Build Coastguard Worker // details.
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker // TODO: The following don't work on Windows. Specifically, sigaction, waitid
11*da0073e9SAndroid Build Coastguard Worker // calls, and SIGCHLD handler. Currently, dummy implementations are provided
12*da0073e9SAndroid Build Coastguard Worker // for Windows.
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker #ifndef _WIN32
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Exceptions.h>
17*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/python_numbers.h>
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
20*da0073e9SAndroid Build Coastguard Worker #include <fmt/format.h>
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker #include <sys/wait.h>
23*da0073e9SAndroid Build Coastguard Worker #include <csignal>
24*da0073e9SAndroid Build Coastguard Worker #include <map>
25*da0073e9SAndroid Build Coastguard Worker #include <set>
26*da0073e9SAndroid Build Coastguard Worker #include <sstream>
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker using namespace torch;
29*da0073e9SAndroid Build Coastguard Worker
// Critical signal handlers should be registered in worker processes before
// they do any work.
// Each handler restores the default disposition and re-raises the signal, so
// that the main process can retrieve the kill information.
// The Python entry point is _set_worker_signal_handlers().
// SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) defines a static
// SA_SIGINFO-style handler named HANDLER_NAME that:
//   1. writes ERROR_MSG to stderr with write(2), ignoring the result;
//   2. resets the disposition of SIGNAL to SIG_DFL; and
//   3. re-raises SIGNAL, or calls _exit(EXIT_FAILURE) if the reset failed.
//
// ERROR_MSG must be a string literal: its length is computed with sizeof. The
// literal's terminating NUL is excluded from the write so that no stray '\0'
// byte ends up on stderr.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                 \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {       \
    /* sizeof counts the trailing NUL of the literal; do not emit it. */ \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);   \
    (void)_w;                                                           \
    struct sigaction sa {};                                             \
    sa.sa_handler = SIG_DFL;                                            \
    sa.sa_flags = 0;                                                    \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                         \
      _exit(EXIT_FAILURE);                                              \
    } else {                                                            \
      raise(SIGNAL);                                                    \
    }                                                                   \
  }
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker // signal(2) is really not portable. So use sigaction.
52*da0073e9SAndroid Build Coastguard Worker // http://man7.org/linux/man-pages/man2/signal.2.html
// Installs `handler` as the SA_SIGINFO-style action for `signal` through
// sigaction(2). `old_sa_ptr` is forwarded to sigaction as the location for the
// previously installed action (it may be nullptr).
// Throws std::runtime_error if the signal mask cannot be cleared or the
// action cannot be installed.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;

  // Short-circuit: sigaction is attempted only once the mask has been cleared.
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for "
          << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker SIGNAL_HANDLER(
70*da0073e9SAndroid Build Coastguard Worker SIGBUS,
71*da0073e9SAndroid Build Coastguard Worker handler_SIGBUS,
72*da0073e9SAndroid Build Coastguard Worker "ERROR: Unexpected bus error encountered in worker. "
73*da0073e9SAndroid Build Coastguard Worker "This might be caused by insufficient shared memory (shm).\n");
74*da0073e9SAndroid Build Coastguard Worker SIGNAL_HANDLER(
75*da0073e9SAndroid Build Coastguard Worker SIGSEGV,
76*da0073e9SAndroid Build Coastguard Worker handler_SIGSEGV,
77*da0073e9SAndroid Build Coastguard Worker "ERROR: Unexpected segmentation fault encountered in worker.\n");
78*da0073e9SAndroid Build Coastguard Worker SIGNAL_HANDLER(
79*da0073e9SAndroid Build Coastguard Worker SIGFPE,
80*da0073e9SAndroid Build Coastguard Worker handler_SIGFPE,
81*da0073e9SAndroid Build Coastguard Worker "ERROR: Unexpected floating-point exception encountered in worker.\n");
82*da0073e9SAndroid Build Coastguard Worker
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, so Python may kill the child processes
// before deleting the loader object. At that point the cleanup methods in
// DataLoader.__del__ have not been called yet, and the SIGCHLD handler would
// print an error saying a worker was killed by SIGTERM. We therefore suppress
// a SIGTERM sent by the main loader process by calling _exit(EXIT_SUCCESS).
// Note that if we exited with a nonzero code, the loader SIGCHLD handler might
// report a RuntimeError again, which would defeat the whole purpose.
// SIGTERM handler (SA_SIGINFO style).
// If the signal was sent by the parent process, exits immediately with
// EXIT_SUCCESS. Otherwise restores the default SIGTERM disposition and
// re-raises the signal; if the disposition cannot be restored, exits with
// EXIT_FAILURE.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  const bool sent_by_parent = (getppid() == info->si_pid);
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;

  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
104*da0073e9SAndroid Build Coastguard Worker
// Weak, empty default definition; it is invoked after the worker signal
// handlers have been installed.
// NOTE(review): presumably a strong definition elsewhere can override this to
// install additional handlers -- confirm against the build.
__attribute__((weak)) void setDataLoaderSignalHandlers() {
}
106*da0073e9SAndroid Build Coastguard Worker
THPModule_setWorkerSignalHandlers(PyObject * module,PyObject * arg)107*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setWorkerSignalHandlers(
108*da0073e9SAndroid Build Coastguard Worker PyObject* module,
109*da0073e9SAndroid Build Coastguard Worker PyObject* arg) {
110*da0073e9SAndroid Build Coastguard Worker HANDLE_TH_ERRORS
111*da0073e9SAndroid Build Coastguard Worker setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr);
112*da0073e9SAndroid Build Coastguard Worker setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr);
113*da0073e9SAndroid Build Coastguard Worker setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr);
114*da0073e9SAndroid Build Coastguard Worker setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr);
115*da0073e9SAndroid Build Coastguard Worker setDataLoaderSignalHandlers();
116*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
117*da0073e9SAndroid Build Coastguard Worker END_HANDLE_TH_ERRORS
118*da0073e9SAndroid Build Coastguard Worker }
119*da0073e9SAndroid Build Coastguard Worker
// Registry of worker process ids, keyed by an integer loader id. Filled by
// THPModule_setWorkerPIDs, emptied by THPModule_removeWorkerPIDs and scanned
// by THPModule_errorIfAnyWorkerFails.
static std::map<int64_t, std::set<pid_t>> worker_pids;
121*da0073e9SAndroid Build Coastguard Worker
THPModule_errorIfAnyWorkerFails(PyObject * module,PyObject * noargs)122*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_errorIfAnyWorkerFails(
123*da0073e9SAndroid Build Coastguard Worker PyObject* module,
124*da0073e9SAndroid Build Coastguard Worker PyObject* noargs) {
125*da0073e9SAndroid Build Coastguard Worker HANDLE_TH_ERRORS
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker // Only check the pids we care about
128*da0073e9SAndroid Build Coastguard Worker for (auto& w : worker_pids) {
129*da0073e9SAndroid Build Coastguard Worker auto& pid_set = w.second;
130*da0073e9SAndroid Build Coastguard Worker for (auto worker_pid : pid_set) {
131*da0073e9SAndroid Build Coastguard Worker // Use waitid rather than waitpid so that we can set NOWAIT, and that
132*da0073e9SAndroid Build Coastguard Worker // Python and other handlers can get whatever info they want about the
133*da0073e9SAndroid Build Coastguard Worker // child.
134*da0073e9SAndroid Build Coastguard Worker siginfo_t infop{};
135*da0073e9SAndroid Build Coastguard Worker infop.si_pid = 0;
136*da0073e9SAndroid Build Coastguard Worker auto error =
137*da0073e9SAndroid Build Coastguard Worker waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
138*da0073e9SAndroid Build Coastguard Worker // ignore errors and case with no waitable child
139*da0073e9SAndroid Build Coastguard Worker if (error < 0 || infop.si_pid == 0)
140*da0073e9SAndroid Build Coastguard Worker continue;
141*da0073e9SAndroid Build Coastguard Worker if (infop.si_code == CLD_EXITED &&
142*da0073e9SAndroid Build Coastguard Worker infop.si_status != EXIT_SUCCESS) { // exit with error
143*da0073e9SAndroid Build Coastguard Worker std::ostringstream oss;
144*da0073e9SAndroid Build Coastguard Worker oss << "DataLoader worker (pid " << worker_pid << ") exited "
145*da0073e9SAndroid Build Coastguard Worker << "unexpectedly with exit code " << infop.si_status << ". "
146*da0073e9SAndroid Build Coastguard Worker << "Details are lost due to multiprocessing. Rerunning with "
147*da0073e9SAndroid Build Coastguard Worker << "num_workers=0 may give better error trace.";
148*da0073e9SAndroid Build Coastguard Worker // This is necessary. Otherwise, the runtime error will kill the other
149*da0073e9SAndroid Build Coastguard Worker // workers, and trigger this again.
150*da0073e9SAndroid Build Coastguard Worker pid_set.clear();
151*da0073e9SAndroid Build Coastguard Worker throw std::runtime_error(oss.str());
152*da0073e9SAndroid Build Coastguard Worker } else if (
153*da0073e9SAndroid Build Coastguard Worker infop.si_code == CLD_KILLED ||
154*da0073e9SAndroid Build Coastguard Worker infop.si_code == CLD_DUMPED) { // killed by signal
155*da0073e9SAndroid Build Coastguard Worker std::ostringstream oss;
156*da0073e9SAndroid Build Coastguard Worker oss << "DataLoader worker (pid " << worker_pid << ") is killed "
157*da0073e9SAndroid Build Coastguard Worker << "by signal: " << strsignal(infop.si_status) << ". ";
158*da0073e9SAndroid Build Coastguard Worker if (infop.si_status == SIGBUS) {
159*da0073e9SAndroid Build Coastguard Worker oss << "It is possible that dataloader's workers are out of shared memory. "
160*da0073e9SAndroid Build Coastguard Worker << "Please try to raise your shared memory limit.";
161*da0073e9SAndroid Build Coastguard Worker }
162*da0073e9SAndroid Build Coastguard Worker // This is necessary. Otherwise, the runtime error will kill the other
163*da0073e9SAndroid Build Coastguard Worker // workers, and trigger this again.
164*da0073e9SAndroid Build Coastguard Worker pid_set.clear();
165*da0073e9SAndroid Build Coastguard Worker throw std::runtime_error(oss.str());
166*da0073e9SAndroid Build Coastguard Worker }
167*da0073e9SAndroid Build Coastguard Worker }
168*da0073e9SAndroid Build Coastguard Worker }
169*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
170*da0073e9SAndroid Build Coastguard Worker END_HANDLE_TH_ERRORS
171*da0073e9SAndroid Build Coastguard Worker }
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
174*da0073e9SAndroid Build Coastguard Worker // of pids we are interested in.
THPModule_setWorkerPIDs(PyObject * module,PyObject * args)175*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) {
176*da0073e9SAndroid Build Coastguard Worker HANDLE_TH_ERRORS
177*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_TYPE(
178*da0073e9SAndroid Build Coastguard Worker PyTuple_GET_SIZE(args) == 2,
179*da0073e9SAndroid Build Coastguard Worker "_set_worker_pids expects exactly 2 arguments.");
180*da0073e9SAndroid Build Coastguard Worker int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
181*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_VALUE(
182*da0073e9SAndroid Build Coastguard Worker worker_pids.find(key) == worker_pids.end(),
183*da0073e9SAndroid Build Coastguard Worker "_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
184*da0073e9SAndroid Build Coastguard Worker PyObject* child_pids = PyTuple_GET_ITEM(args, 1);
185*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_TYPE(
186*da0073e9SAndroid Build Coastguard Worker PyTuple_Check(child_pids),
187*da0073e9SAndroid Build Coastguard Worker "_set_worker_pids expects a tuple for child_pids, but got ",
188*da0073e9SAndroid Build Coastguard Worker Py_TYPE(child_pids)->tp_name,
189*da0073e9SAndroid Build Coastguard Worker ".");
190*da0073e9SAndroid Build Coastguard Worker std::set<pid_t> pids_set = {};
191*da0073e9SAndroid Build Coastguard Worker auto size = PyTuple_GET_SIZE(child_pids);
192*da0073e9SAndroid Build Coastguard Worker for (const auto idx : c10::irange(size)) {
193*da0073e9SAndroid Build Coastguard Worker PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
194*da0073e9SAndroid Build Coastguard Worker pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
195*da0073e9SAndroid Build Coastguard Worker }
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker worker_pids[key] = pids_set;
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
200*da0073e9SAndroid Build Coastguard Worker END_HANDLE_TH_ERRORS
201*da0073e9SAndroid Build Coastguard Worker }
202*da0073e9SAndroid Build Coastguard Worker
THPModule_removeWorkerPIDs(PyObject * module,PyObject * loader_id)203*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_removeWorkerPIDs(
204*da0073e9SAndroid Build Coastguard Worker PyObject* module,
205*da0073e9SAndroid Build Coastguard Worker PyObject* loader_id) {
206*da0073e9SAndroid Build Coastguard Worker HANDLE_TH_ERRORS
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker int64_t key = THPUtils_unpackLong(loader_id);
209*da0073e9SAndroid Build Coastguard Worker auto it = worker_pids.find(key);
210*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_VALUE(
211*da0073e9SAndroid Build Coastguard Worker it != worker_pids.end(),
212*da0073e9SAndroid Build Coastguard Worker "Cannot find worker information for _BaseDataLoaderIter with id ",
213*da0073e9SAndroid Build Coastguard Worker key);
214*da0073e9SAndroid Build Coastguard Worker worker_pids.erase(it);
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
217*da0073e9SAndroid Build Coastguard Worker END_HANDLE_TH_ERRORS
218*da0073e9SAndroid Build Coastguard Worker }
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker #undef SIGNAL_HANDLER
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker #else
223*da0073e9SAndroid Build Coastguard Worker // dummy implementations for windows
224*da0073e9SAndroid Build Coastguard Worker
THPModule_setWorkerSignalHandlers(PyObject * module,PyObject * _ignored)225*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setWorkerSignalHandlers(
226*da0073e9SAndroid Build Coastguard Worker PyObject* module,
227*da0073e9SAndroid Build Coastguard Worker PyObject* _ignored) {
228*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
229*da0073e9SAndroid Build Coastguard Worker }
230*da0073e9SAndroid Build Coastguard Worker
THPModule_setWorkerPIDs(PyObject * module,PyObject * _ignored)231*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) {
232*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
233*da0073e9SAndroid Build Coastguard Worker }
234*da0073e9SAndroid Build Coastguard Worker
THPModule_removeWorkerPIDs(PyObject * module,PyObject * _ignored)235*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_removeWorkerPIDs(
236*da0073e9SAndroid Build Coastguard Worker PyObject* module,
237*da0073e9SAndroid Build Coastguard Worker PyObject* _ignored) {
238*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
239*da0073e9SAndroid Build Coastguard Worker }
240*da0073e9SAndroid Build Coastguard Worker
THPModule_errorIfAnyWorkerFails(PyObject * module,PyObject * _ignored)241*da0073e9SAndroid Build Coastguard Worker static PyObject* THPModule_errorIfAnyWorkerFails(
242*da0073e9SAndroid Build Coastguard Worker PyObject* module,
243*da0073e9SAndroid Build Coastguard Worker PyObject* _ignored) {
244*da0073e9SAndroid Build Coastguard Worker Py_RETURN_NONE;
245*da0073e9SAndroid Build Coastguard Worker }
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker #endif
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
250*da0073e9SAndroid Build Coastguard Worker PyMethodDef DataLoaderMethods[] = {
251*da0073e9SAndroid Build Coastguard Worker {"_set_worker_signal_handlers",
252*da0073e9SAndroid Build Coastguard Worker THPModule_setWorkerSignalHandlers,
253*da0073e9SAndroid Build Coastguard Worker METH_NOARGS,
254*da0073e9SAndroid Build Coastguard Worker nullptr},
255*da0073e9SAndroid Build Coastguard Worker {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr},
256*da0073e9SAndroid Build Coastguard Worker {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr},
257*da0073e9SAndroid Build Coastguard Worker {"_error_if_any_worker_fails",
258*da0073e9SAndroid Build Coastguard Worker THPModule_errorIfAnyWorkerFails,
259*da0073e9SAndroid Build Coastguard Worker METH_NOARGS,
260*da0073e9SAndroid Build Coastguard Worker nullptr},
261*da0073e9SAndroid Build Coastguard Worker {nullptr, nullptr, 0, nullptr}};
262