xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/matmul.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport numpy as np
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker
class MatMulBench(benchmark.Benchmark):
    """Batched matrix-multiply benchmark: ``[B, M, N] @ [B, N, K] -> [B, M, K]``.

    ``N`` is the shared inner dimension that is contracted away; the result
    has shape ``[B, M, K]``.
    """

    def __init__(self, mode, device, dtype, B, M, N, K):
        """Allocate the two random batched operands.

        Args:
            mode: Benchmark mode, forwarded to the base class. ``"fwd"`` is
                special-cased by the workload estimates below.
            device: Device the operands are created on.
            dtype: Element dtype of the operands.
            B: Batch size.
            M: Rows of the left operand (and of the result).
            N: Inner dimension shared by both operands (contracted away).
            K: Columns of the right operand (and of the result).
        """
        super().__init__(mode, device, dtype)
        self.B = B
        self.M = M
        self.N = N
        self.K = K
        # ``self.rand`` and ``self.requires_grad`` are provided by the base
        # class; presumably requires_grad is derived from ``mode`` -- confirm.
        self.d1 = self.rand(
            [B, M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [B, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return the batched product ``d1 @ d2`` using the base-class ``matmul``."""
        y = self.matmul(d1, d2)
        return y

    def reference(self):
        """Return the NumPy ground truth ``np.matmul(d1, d2)`` for the stored operands."""
        return np.matmul(self.numpy(self.d1), self.numpy(self.d2))

    def config(self):
        """Return this instance's configuration as ``[B, M, N, K]``."""
        return [self.B, self.M, self.N, self.K]

    @staticmethod
    def module():
        """Return the benchmark's module name, ``"batch_matmul"``."""
        return "batch_matmul"

    def memory_workload(self):
        """Estimate memory traffic, in elements (element size is not applied).

        Returns:
            A dict with ``"sol"`` (presumably the speed-of-light minimum) and
            ``"algorithmic"`` estimates. Each is the combined element count of
            both inputs and the output, multiplied by a mode-dependent pass
            count.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            # Any non-"fwd" mode is assumed to include a backward pass; how the
            # driver interprets these counts is defined outside this file.
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1)

        # d1 is [B, M, N], d2 is [B, N, K], and the product is [B, M, K].
        # The output term must be B*M*K; counting d1's B*M*N a second time
        # instead misestimates traffic whenever N != K.
        buffer_size = (
            self.B * self.M * self.N
            + self.B * self.N * self.K
            + self.B * self.M * self.K
        )
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    def compute_workload(self):
        """Return the estimated arithmetic operation count.

        One ``[B, M, N] @ [B, N, K]`` product costs ``2 * B * M * N * K`` ops
        (a multiply and an add per inner-product term).
        """
        if self.mode == "fwd":
            count = 1
        else:
            # Presumably the forward product plus the two gradient matmuls
            # (w.r.t. d1 and d2), each with the same op count as the forward.
            count = 1 + (1 + 1)

        op_count = 2 * self.B * self.M * self.N * self.K

        return op_count * count

    @staticmethod
    def default_configs():
        """Return the default list of ``[B, M, N, K]`` configurations."""
        return [[128, 64, 128, 256]]
67*da0073e9SAndroid Build Coastguard Worker
# Register MatMulBench through the local ``benchmark`` module's
# ``register_benchmark_class`` hook.
benchmark.register_benchmark_class(MatMulBench)
69