xref: /aosp_15_r20/external/pytorch/torch/utils/data/_utils/signal_handling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Signal handling for multiprocessing data loading.
3
4NOTE [ Signal handling in multiprocessing data loading ]
5
6In cases like DataLoader, if a worker process dies due to bus error/segfault
7or just hangs, the main process will hang waiting for data. This is difficult
8to avoid on PyTorch side as it can be caused by limited shm, or other
9libraries users call in the workers. In this file and `DataLoader.cpp`, we make
10our best effort to provide some error message to users when such unfortunate
11events happen.
12
13When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
14mapping defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
15via `_set_worker_pids`.
16
17When an error happens in a worker process, the main process receives a SIGCHLD,
18and Python will eventually call the handler registered below
19(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
20call checks all registered worker pids and raises a proper error to
21prevent the main process from hanging while waiting for data from the worker.
22
23Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
24`_set_worker_signal_handlers` is called to register critical signal handlers
25(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just print an error
26message to stderr before triggering the default handler. So a message will also
27be printed from the worker process when it is killed by such signals.
28
29See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
30this signal handling design and other mechanisms we implement to make our
31multiprocessing data loading robust to errors.
32"""
33
34import signal
35import threading
36
37# Some of the following imported functions are not used in this file, but are
38# meant to be accessed as `_utils.signal_handling.XXXXX`.
39from torch._C import (  # noqa: F401
40    _error_if_any_worker_fails,
41    _remove_worker_pids,
42    _set_worker_pids,
43    _set_worker_signal_handlers,
44)
45
46from . import IS_WINDOWS
47
48
# In this file the flag is read and written by `_set_SIGCHLD_handler`: it is
# set to True right after the handler is installed, and later calls return
# early once it is True.
49_SIGCHLD_handler_set = False
50r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
51handler needs to be set for all DataLoaders in a process."""
52
53
def _set_SIGCHLD_handler():
    r"""Install, at most once per process, the SIGCHLD handler for worker failures.

    The call returns without doing anything when any of the following holds:

    * ``IS_WINDOWS`` is true (Windows doesn't support a SIGCHLD handler);
    * the calling thread is not the main thread (signal handlers can't be
      set from child threads);
    * ``_SIGCHLD_handler_set`` is already true, i.e. an earlier call has
      installed the handler.

    Otherwise it registers a handler that first calls
    ``_error_if_any_worker_fails()`` and then forwards ``(signum, frame)`` to
    the handler that was registered before, provided that one was a Python
    callable. Finally it sets ``_SIGCHLD_handler_set`` to ``True``.

    Note that if ``_error_if_any_worker_fails()`` raises, the previously
    registered handler is not invoked for that signal delivery.
    """
    global _SIGCHLD_handler_set

    # Windows doesn't support SIGCHLD handler
    if IS_WINDOWS:
        return
    # Can't set signal handlers in child threads. Use the public
    # `threading.main_thread()` accessor instead of checking against the
    # private `threading._MainThread` class.
    if threading.current_thread() is not threading.main_thread():
        return
    if _SIGCHLD_handler_set:
        return

    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # This doesn't catch default handler, but SIGCHLD default handler is a
        # no-op.
        previous_handler = None

    def handler(signum, frame):
        # This following call uses `waitid` with WNOHANG from C side. Therefore,
        # Python can still get and update the process status successfully.
        _error_if_any_worker_fails()
        if previous_handler is not None:
            # Redundant at runtime given the `callable` check above;
            # presumably kept so type checkers can narrow the closure variable.
            assert callable(previous_handler)
            # Chain to whatever handler was installed before ours.
            previous_handler(signum, frame)

    signal.signal(signal.SIGCHLD, handler)
    _SIGCHLD_handler_set = True
80