from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo
import torch._dynamo.config
import torch._inductor.config as config


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_aten_bmm(a, b):
    """bmm compiled with inductor; callers set config.triton.use_bmm = False first,
    so this path presumably lowers to the ATen kernel — confirm against inductor config."""
    return torch.bmm(a, b)


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_triton_bmm(a, b):
    """bmm compiled with inductor; callers set config.triton.use_bmm = True first,
    so this path presumably lowers to the Triton bmm template — confirm against inductor config."""
    return torch.bmm(a, b)


def torch_bmm(a, b):
    """Eager-mode baseline batched matrix multiply."""
    return torch.bmm(a, b)


def test_total_time(shapes):
    """Benchmark eager vs. inductor (ATen and Triton) bmm for each shape pair.

    Args:
        shapes: iterable of (a_shape, b_shape) pairs, each a 3-element batched
            matmul shape compatible with torch.bmm.

    Prints one semicolon-separated row per pair with mean times in milliseconds.
    """
    print("shape; torch bmm; inductor aten bmm; inductor triton bmm")
    # Unpack pairs directly rather than indexing via range(len(shapes)).
    for a_shape, b_shape in shapes:
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Warm-up calls: trigger compilation of both inductor variants up front
        # so compile time is excluded from the timed runs below.
        config.triton.use_bmm = False
        inductor_aten_bmm(a, b)

        config.triton.use_bmm = True
        inductor_triton_bmm(a, b)

        torch_ms = time_with_torch_timer(torch_bmm, (a, b)).mean * 1000

        config.triton.use_bmm = False
        ind_aten_ms = time_with_torch_timer(inductor_aten_bmm, (a, b)).mean * 1000

        config.triton.use_bmm = True
        ind_triton_ms = time_with_torch_timer(inductor_triton_bmm, (a, b)).mean * 1000

        print(torch_ms, ind_aten_ms, ind_triton_ms, sep="; ")


if __name__ == "__main__":
    shapes = [
        # BERT (all)
        ([192, 128, 64], [192, 64, 128]),
        ([192, 128, 128], [192, 128, 64]),
        # hf_GPT2 (all)
        ([12, 1024, 1024], [12, 1024, 64]),
        ([12, 1024, 64], [12, 64, 1024]),
        # hf_Albert (all)
        ([12, 512, 64], [12, 64, 512]),
        ([12, 512, 512], [12, 512, 64]),
    ]

    test_total_time(shapes)