# xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/inductor_bmm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from benchmark_helper import time_with_torch_timer
2
3import torch
4import torch._dynamo
5import torch._dynamo.config
6import torch._inductor.config as config
7
8
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_aten_bmm(a, b):
    """Inductor-compiled batched matmul.

    Callers (see test_total_time) set config.triton.use_bmm = False before
    invoking this variant — presumably selecting Inductor's ATen bmm
    lowering, as the function name suggests; confirm against the config flag.
    """
    return torch.bmm(a, b)
12
13
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_triton_bmm(a, b):
    """Inductor-compiled batched matmul.

    Callers (see test_total_time) set config.triton.use_bmm = True before
    invoking this variant — presumably selecting Inductor's Triton bmm
    template, as the function name suggests; confirm against the config flag.

    NOTE: a separate function (rather than reusing inductor_aten_bmm) is
    needed so each config setting gets its own compiled artifact.
    """
    return torch.bmm(a, b)
17
18
def torch_bmm(a, b):
    """Eager-mode baseline: plain batched matrix multiply, no compilation."""
    return a.bmm(b)
21
22
def test_total_time(shapes):
    """Benchmark torch.bmm on CUDA fp16 inputs and print one row per shape.

    For each (a_shape, b_shape) pair this times three variants — eager
    torch.bmm, Inductor with its ATen bmm path, and Inductor with its
    Triton bmm path — and prints the mean latency of each in milliseconds,
    semicolon-separated.

    Args:
        shapes: iterable of (a_shape, b_shape) pairs, each a list/tuple of
            three ints (batch, M, K) x (batch, K, N).
    """
    print("shape; torch bmm; inductor aten bmm; inductor triton bmm")
    # Iterate the pairs directly instead of indexing with range(len(...)).
    for a_shape, b_shape in shapes:
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Warm-up calls: trigger compilation of both variants up front so
        # compile time is excluded from the timed runs below.
        config.triton.use_bmm = False
        inductor_aten_bmm(a, b)

        config.triton.use_bmm = True
        inductor_triton_bmm(a, b)

        torch_ms = time_with_torch_timer(torch_bmm, (a, b)).mean * 1000

        # The config flag is toggled before each timed run so the matching
        # compiled artifact is exercised.
        config.triton.use_bmm = False
        ind_aten_ms = time_with_torch_timer(inductor_aten_bmm, (a, b)).mean * 1000

        config.triton.use_bmm = True
        ind_triton_ms = time_with_torch_timer(inductor_triton_bmm, (a, b)).mean * 1000

        print(torch_ms, ind_aten_ms, ind_triton_ms, sep="; ")
46
47
if __name__ == "__main__":
    # Shape pairs drawn from real model workloads; each entry is
    # (a_shape, b_shape) with a_shape = [batch, M, K], b_shape = [batch, K, N].
    shapes = [
        # BERT (all)
        ([192, 128, 64], [192, 64, 128]),
        ([192, 128, 128], [192, 128, 64]),
        # hf_GPT2 (all)
        ([12, 1024, 1024], [12, 1024, 64]),
        ([12, 1024, 64], [12, 64, 1024]),
        # hf_Albert (all)
        ([12, 512, 64], [12, 64, 512]),
        ([12, 512, 512], [12, 512, 64]),
    ]

    test_total_time(shapes)
62