# mypy: allow-untyped-defs
"""Helpers for the cuDNN RNN backend: mode lookup and dropout-state caching."""
import torch.cuda


try:
    from torch._C import _cudnn
except ImportError:
    # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
    # so it's safe to not emit any checks here.
    _cudnn = None  # type: ignore[assignment]


# Maps the public RNN mode name to the attribute name on ``_cudnn.RNNMode``.
# Attribute names (not values) are stored so that nothing touches ``_cudnn``
# at import time, when it may be ``None``.
_RNN_MODE_ATTRS = {
    "RNN_RELU": "rnn_relu",
    "RNN_TANH": "rnn_tanh",
    "LSTM": "lstm",
    "GRU": "gru",
}


def get_cudnn_mode(mode):
    """Return the integer value of the ``_cudnn.RNNMode`` member for *mode*.

    Args:
        mode: One of ``"RNN_RELU"``, ``"RNN_TANH"``, ``"LSTM"`` or ``"GRU"``.

    Returns:
        ``int(_cudnn.RNNMode.<member>)`` for the matching member.

    Raises:
        ValueError: If *mode* is not one of the names listed above.
        AttributeError: If *mode* is valid but the ``_cudnn`` import failed
            (``_cudnn`` is ``None``); callers are expected to guard with
            ``torch.backends.cudnn.is_available()``.
    """
    # The isinstance guard keeps unhashable inputs on the "unknown mode" path
    # instead of failing inside the dict lookup.
    attr_name = _RNN_MODE_ATTRS.get(mode) if isinstance(mode, str) else None
    if attr_name is None:
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working.
        raise ValueError(f"Unknown mode: {mode}")
    return int(getattr(_cudnn.RNNMode, attr_name))


# NB: We don't actually need this class anymore (in fact, we could serialize the
# dropout state for even better reproducibility), but it is kept for backwards
# compatibility for old models.
class Unserializable:
    """Wrapper whose payload is dropped when the wrapper is pickled.

    The wrapped object is available through :meth:`get`. Pickling stores only
    a placeholder string, and unpickling leaves the payload as ``None``.
    """

    def __init__(self, inner):
        """Store *inner* as the wrapped payload."""
        self.inner = inner

    def get(self):
        """Return the wrapped payload (``None`` after unpickling)."""
        return self.inner

    def __getstate__(self):
        """Return a placeholder in place of the real payload."""
        # Note: can't return {}, because python2 won't call __setstate__
        # if the value evaluates to False
        return "<unserializable>"

    def __setstate__(self, state):
        """Ignore *state* and reset the payload to ``None``."""
        self.inner = None


def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the dropout state cached for the current CUDA device.

    The cache key is ``"desc_<index>"`` where ``<index>`` is
    ``torch.cuda.current_device()``. The entry is (re)created when it is
    missing or when the stored wrapper holds ``None``.

    Args:
        dropout: Dropout probability used when *train* is truthy.
        train: When falsy, the effective dropout probability is ``0``.
        dropout_seed: Seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: Mutable mapping used as the cache; updated in place.

    Returns:
        The payload of the cached :class:`Unserializable` entry: ``None`` when
        the effective dropout probability is ``0``, otherwise the result of
        ``torch._cudnn_init_dropout_state``.
    """
    dropout_desc_name = "desc_" + str(torch.cuda.current_device())
    dropout_p = dropout if train else 0
    if (dropout_desc_name not in dropout_state) or (
        dropout_state[dropout_desc_name].get() is None
    ):
        if dropout_p == 0:
            # Nothing to initialise; cache an empty wrapper.
            dropout_state[dropout_desc_name] = Unserializable(None)
        else:
            # NOTE(review): presumably returns a uint8 CUDA tensor holding the
            # cuDNN dropout state -- confirm against the op's schema.
            dropout_state[dropout_desc_name] = Unserializable(
                torch._cudnn_init_dropout_state(  # type: ignore[call-arg]
                    dropout_p,
                    train,
                    dropout_seed,
                    self_ty=torch.uint8,
                    device=torch.device("cuda"),
                )
            )
    dropout_ts = dropout_state[dropout_desc_name].get()
    return dropout_ts