"""Registry of named model runners.

Each entry of ``nn_runners`` pairs a model creator from ``.factory`` with a
context-manager class describing the environment the runner should use.
"""

import os
from collections import namedtuple
from functools import partial

import torch
import torchvision.models as cnn

from .factory import (
    dropoutlstm_creator,
    imagenet_cnn_creator,
    layernorm_pytorch_lstm_creator,
    lnlstm_creator,
    lstm_creator,
    lstm_multilayer_creator,
    lstm_premul_bias_creator,
    lstm_premul_creator,
    lstm_simple_creator,
    pytorch_lstm_creator,
    varlen_lstm_creator,
    varlen_pytorch_lstm_creator,
)


class DisableCuDNN:
    """Context manager that sets ``torch.backends.cudnn.enabled = False``.

    The previous value is saved on entry and restored on exit. Exceptions
    raised inside the ``with`` block are not suppressed.
    """

    def __enter__(self):
        self.saved = torch.backends.cudnn.enabled
        torch.backends.cudnn.enabled = False

    def __exit__(self, *args, **kwargs):
        torch.backends.cudnn.enabled = self.saved


class DummyContext:
    """No-op context manager for runners that need no environment changes."""

    def __enter__(self):
        pass

    def __exit__(self, *args, **kwargs):
        pass


# Lower-cased PYTORCH_JIT values that mean "JIT disabled".
# NOTE(review): intended to match the false spellings torch.jit accepts for
# PYTORCH_JIT; confirm against torch/jit/_state.py if that parser changes.
_JIT_DISABLED_VALUES = frozenset({"0", "false", "no"})


class AssertNoJIT:
    """Context manager that checks TorchScript is disabled via ``PYTORCH_JIT``.

    Raises:
        AssertionError: on entry, if ``PYTORCH_JIT`` is unset (treated as
            enabled, i.e. ``"1"``) or set to a value not in
            ``_JIT_DISABLED_VALUES``.
    """

    def __enter__(self):
        # os.environ values are strings, so a plain truthiness test would treat
        # PYTORCH_JIT=0 ("0" is a non-empty string) as enabled. Parse instead.
        value = os.environ.get("PYTORCH_JIT", "1")
        if value.strip().lower() not in _JIT_DISABLED_VALUES:
            # Raise explicitly rather than `assert` so the check survives `-O`.
            raise AssertionError(
                "expected TorchScript to be disabled (PYTORCH_JIT=0), "
                f"got PYTORCH_JIT={value!r}"
            )

    def __exit__(self, *args, **kwargs):
        pass


# One registry entry:
#   name    -- the runner's name (matches its key in nn_runners)
#   creator -- model-creator callable from .factory, possibly with keyword
#              arguments pre-bound via functools.partial
#   context -- a context-manager *class* (not an instance); presumably the
#              caller instantiates it, e.g. `with runner.context(): ...`
RNNRunner = namedtuple(
    "RNNRunner",
    [
        "name",
        "creator",
        "context",
    ],
)


def get_nn_runners(*names):
    """Return the ``RNNRunner`` entries for ``names``, in the order given.

    Raises:
        KeyError: if any of ``names`` is not a key of ``nn_runners``. The
            message lists the unknown names and the available ones.
    """
    unknown = [name for name in names if name not in nn_runners]
    if unknown:
        raise KeyError(
            f"unknown runner(s) {unknown}; available: {sorted(nn_runners)}"
        )
    return [nn_runners[name] for name in names]


nn_runners = {
    # LSTMs built by pytorch_lstm_creator (cuDNN left at its current setting).
    "cudnn": RNNRunner("cudnn", pytorch_lstm_creator, DummyContext),
    "cudnn_dropout": RNNRunner(
        "cudnn_dropout", partial(pytorch_lstm_creator, dropout=0.4), DummyContext
    ),
    "cudnn_layernorm": RNNRunner(
        "cudnn_layernorm", layernorm_pytorch_lstm_creator, DummyContext
    ),
    # Variable-length-input variants.
    "vl_cudnn": RNNRunner("vl_cudnn", varlen_pytorch_lstm_creator, DummyContext),
    "vl_jit": RNNRunner(
        "vl_jit", partial(varlen_lstm_creator, script=True), DummyContext
    ),
    "vl_py": RNNRunner("vl_py", varlen_lstm_creator, DummyContext),
    # Same creator as "cudnn" but run with cuDNN disabled (presumably so the
    # ATen implementation is exercised).
    "aten": RNNRunner("aten", pytorch_lstm_creator, DisableCuDNN),
    # Scripted custom-LSTM variants ("py" below shows lstm_creator scripts
    # unless script=False is passed).
    "jit": RNNRunner("jit", lstm_creator, DummyContext),
    "jit_premul": RNNRunner("jit_premul", lstm_premul_creator, DummyContext),
    "jit_premul_bias": RNNRunner(
        "jit_premul_bias", lstm_premul_bias_creator, DummyContext
    ),
    "jit_simple": RNNRunner("jit_simple", lstm_simple_creator, DummyContext),
    "jit_multilayer": RNNRunner(
        "jit_multilayer", lstm_multilayer_creator, DummyContext
    ),
    "jit_layernorm": RNNRunner("jit_layernorm", lnlstm_creator, DummyContext),
    "jit_layernorm_decom": RNNRunner(
        "jit_layernorm_decom",
        partial(lnlstm_creator, decompose_layernorm=True),
        DummyContext,
    ),
    "jit_dropout": RNNRunner("jit_dropout", dropoutlstm_creator, DummyContext),
    # The custom LSTM left as plain Python (not scripted).
    "py": RNNRunner("py", partial(lstm_creator, script=False), DummyContext),
    # torchvision ResNets; imagenet_cnn_creator defaults to jit=True, so the
    # "_jit" entries are the scripted ones.
    "resnet18": RNNRunner(
        "resnet18", imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext
    ),
    "resnet18_jit": RNNRunner(
        "resnet18_jit", imagenet_cnn_creator(cnn.resnet18), DummyContext
    ),
    "resnet50": RNNRunner(
        "resnet50", imagenet_cnn_creator(cnn.resnet50, jit=False), DummyContext
    ),
    "resnet50_jit": RNNRunner(
        "resnet50_jit", imagenet_cnn_creator(cnn.resnet50), DummyContext
    ),
}