xref: /aosp_15_r20/external/pytorch/torch/backends/cudnn/rnn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch.cuda
3
4
5try:
6    from torch._C import _cudnn
7except ImportError:
8    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
9    # so it's safe to not emit any checks here.
10    _cudnn = None  # type: ignore[assignment]
11
12
def get_cudnn_mode(mode):
    """Translate an RNN mode name into the integer value of cuDNN's RNN mode enum.

    Args:
        mode: One of ``"RNN_RELU"``, ``"RNN_TANH"``, ``"LSTM"`` or ``"GRU"``.

    Returns:
        ``int`` value of the matching ``_cudnn.RNNMode`` member.

    Raises:
        ValueError: if ``mode`` is not one of the names above. ``ValueError``
            subclasses ``Exception``, so existing ``except Exception`` handlers
            keep working.
    """
    # Unknown names fall through every comparison without touching ``_cudnn``,
    # which is ``None`` when the cuDNN bindings failed to import.
    if mode == "RNN_RELU":
        return int(_cudnn.RNNMode.rnn_relu)
    elif mode == "RNN_TANH":
        return int(_cudnn.RNNMode.rnn_tanh)
    elif mode == "LSTM":
        return int(_cudnn.RNNMode.lstm)
    elif mode == "GRU":
        return int(_cudnn.RNNMode.gru)
    else:
        raise ValueError(f"Unknown mode: {mode}")
24
25
26# NB: We don't actually need this class anymore (in fact, we could serialize the
27# dropout state for even better reproducibility), but it is kept for backwards
28# compatibility for old models.
class Unserializable:
    """Hold a value in memory but drop it when the wrapper is pickled.

    Pickling stores only a placeholder string. Unpickling yields a wrapper
    whose ``inner`` (and therefore :meth:`get`) is ``None``.
    """

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        """Return the wrapped object, or ``None`` after an unpickle round-trip."""
        wrapped = self.inner
        return wrapped

    def __getstate__(self):
        # Emit a truthy placeholder instead of the wrapped object. An empty
        # dict is avoided on purpose: per the original note, Python 2's pickle
        # does not invoke ``__setstate__`` when the saved state is falsy.
        placeholder = "<unserializable>"
        return placeholder

    def __setstate__(self, state):
        # The saved placeholder carries no data, so restore an empty wrapper.
        self.inner = None
43
44
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the cached cuDNN dropout state for the current CUDA device.

    Creates the state lazily on first use. The cache entry lives in
    ``dropout_state`` under the key ``"desc_<current device index>"``, wrapped
    in :class:`Unserializable` so it is not pickled.

    Args:
        dropout: dropout probability, used only when ``train`` is truthy.
        train: when falsy, the effective dropout probability is ``0``.
        dropout_seed: seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: mutable mapping used as the cache; it is updated in
            place. It must support ``in`` and item get/set.

    Returns:
        The cached object (the result of ``torch._cudnn_init_dropout_state``),
        or ``None`` when the effective dropout probability is ``0`` and no
        state has been created yet.
    """
    key = "desc_" + str(torch.cuda.current_device())
    effective_p = dropout if train else 0

    # A missing entry, or one holding None (stored for p == 0 or cleared by
    # unpickling), triggers (re)creation. An existing non-None entry is
    # returned as-is, whatever the current ``dropout``/``train`` values.
    needs_init = key not in dropout_state or dropout_state[key].get() is None
    if needs_init:
        payload = None
        if effective_p != 0:
            payload = torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
                effective_p,
                train,
                dropout_seed,
                self_ty=torch.uint8,
                device=torch.device("cuda"),
            )
        dropout_state[key] = Unserializable(payload)

    return dropout_state[key].get()
65