xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/pooling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from . import benchmark
2
3
class PoolingBench(benchmark.Benchmark):
    """Benchmark a 2-D pooling op over a single input tensor of shape [N, C, H, W].

    The pooling flavour is chosen by ``case`` (``"maxpool"`` or ``"avgpool"``);
    subclasses fix ``case`` and forward the remaining constructor arguments.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, C, H, W):
        """Create the benchmark and allocate its random input.

        Args:
            case: ``"maxpool"`` or ``"avgpool"``; selects the op run in ``forward``.
            mode: forwarded to the base class; ``memory_workload`` treats
                ``"fwd"`` as forward-only and anything else as forward+backward.
            device: device on which the input tensor is allocated.
            dtype: dtype of the input tensor.
            kernel_size: pooling window size, passed to the pooling op (stride 1).
            N, C, H, W: dimensions of the input tensor ``[N, C, H, W]``.
        """
        super().__init__(mode, device)
        self.case = case
        self.kernel_size = kernel_size
        self.N = N
        self.C = C
        self.H = H
        self.W = W
        # NOTE(review): `self.rand` and `self.requires_grad` come from the base
        # Benchmark class; presumably requires_grad is derived from `mode` — confirm.
        self.data = self.rand(
            [N, C, H, W], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

    def forward(self):
        """Run the selected pooling op (stride 1) over ``self.data`` and return the result.

        Raises:
            ValueError: if ``self.case`` is not ``"maxpool"`` or ``"avgpool"``.
        """
        if self.case == "maxpool":
            return self.max_pool2d(self.data, self.kernel_size, stride=1)
        if self.case == "avgpool":
            return self.avg_pool2d(self.data, self.kernel_size, stride=1)
        # Without this, an unknown case would fall through and fail later with an
        # unhelpful UnboundLocalError instead of naming the bad value.
        raise ValueError(
            f"unsupported pooling case {self.case!r}; expected 'maxpool' or 'avgpool'"
        )

    def config(self):
        """Return the benchmark parameters as ``[kernel_size, N, C, H, W]``."""
        return [self.kernel_size, self.N, self.C, self.H, self.W]

    def memory_workload(self):
        """Return estimated memory traffic for one iteration, in tensor elements.

        The result is a dict with keys ``"sol"`` and ``"algorithmic"``; each value
        is ``N*C*H*W`` times a per-mode count of buffer accesses. In ``"fwd"`` mode
        both counts are 2. Otherwise ``"sol"`` is 4 and ``"algorithmic"`` is 5.
        """
        if self.mode == "fwd":
            # presumably: one input read + one output write
            sol_count = 1 + 1
            algorithmic_count = 1 + 1
        else:
            # presumably: forward traffic plus backward traffic
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (1 + 1) + (2 + 1)

        # NOTE(review): every access is counted at the input size N*C*H*W, although a
        # stride-1 pool's output is smaller unless padded — this is an approximation.
        buffer_size = self.N * self.C * self.H * self.W
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default parameter sets, each ordered like ``config()``."""
        return [[3, 16, 32, 256, 256]]
44
45
class MaxPoolBench(PoolingBench):
    """PoolingBench specialised to the ``"maxpool"`` case."""

    @staticmethod
    def module():
        """Return the name this benchmark is known by."""
        module_name = "maxpool"
        return module_name

    def __init__(self, *args):
        # Prepend the fixed case tag; all other constructor arguments pass through.
        case_name = "maxpool"
        super().__init__(case_name, *args)
53
54
class AvgPoolBench(PoolingBench):
    """PoolingBench specialised to the ``"avgpool"`` case."""

    @staticmethod
    def module():
        """Return the name this benchmark is known by."""
        module_name = "avgpool"
        return module_name

    def __init__(self, *args):
        # Prepend the fixed case tag; all other constructor arguments pass through.
        case_name = "avgpool"
        super().__init__(case_name, *args)
62
63
# Register both pooling benchmarks, max first then average, as before.
for _pooling_bench_cls in (MaxPoolBench, AvgPoolBench):
    benchmark.register_benchmark_class(_pooling_bench_cls)
# Drop the loop variable so the module namespace matches the original.
del _pooling_bench_cls
66