# xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/normalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from . import benchmark, tensor_engine
2
3
class NormalizationBench(benchmark.Benchmark):
    """Common fixture shared by the normalization benchmarks in this module.

    Allocates an ``[N, C, H, W]`` input through ``nchw_rand`` and two
    length-``C`` tensors used as running mean and running variance.
    Subclasses provide ``forward`` and ``module``.
    """

    def __init__(self, mode, device, dtype, N, C, H, W):
        super().__init__(mode, device, dtype)
        self.N, self.C, self.H, self.W = N, C, H, W

        # Keyword arguments common to every tensor allocated below.
        alloc = {"device": device, "dtype": dtype}
        self.data = self.nchw_rand(
            [self.N, self.C, self.H, self.W],
            requires_grad=self.requires_grad,
            **alloc,
        )
        self.running_mean = self.rand([self.C], **alloc)
        self.running_var = self.rand([self.C], **alloc)
        # True only in "both" mode; BatchNormBench passes it as ``training=``.
        self.training = self.mode == "both"

    def config(self):
        """Return the benchmark shape as ``[N, C, H, W]``."""
        return [getattr(self, dim) for dim in "NCHW"]

    def memory_workload(self):
        """Return estimated memory traffic in bytes.

        Both entries are multiples of one input-sized buffer. "sol" and
        "algorithmic" are presumably the minimal and as-implemented traffic
        estimates. Any mode other than "fwd" adds extra terms on top of the
        forward counts.
        """
        fwd_sol, fwd_algo = 1 + 1, 2 + 1
        if self.mode == "fwd":
            sol_count, algorithmic_count = fwd_sol, fwd_algo
        else:
            sol_count = fwd_sol + (1 + 1)
            algorithmic_count = fwd_algo + (3 + 1)

        # NOTE(review): 4 bytes per element is hard-coded; self.dtype is not consulted.
        bytes_per_buffer = self.N * self.C * self.H * self.W * 4
        return {
            "sol": bytes_per_buffer * sol_count,
            "algorithmic": bytes_per_buffer * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default list of ``[N, C, H, W]`` configurations."""
        return [[128, 32, 128, 128]]
42
43
class BatchNormBench(NormalizationBench):
    """Batch-normalization benchmark over the shared NCHW input."""

    def forward(self):
        # ``self.training`` is set by the base class from the benchmark mode.
        return self.batch_norm(
            self.data,
            self.running_mean,
            self.running_var,
            training=self.training,
        )

    @staticmethod
    def module():
        """Return the name under which this benchmark is registered."""
        return "batchnorm"
54
55
class InstanceNormBench(NormalizationBench):
    """Instance-normalization benchmark over the shared NCHW input."""

    def forward(self):
        return self.instance_norm(self.data)

    @staticmethod
    def module():
        """Return the name under which this benchmark is registered."""
        return "instance_norm"

    def is_supported(self):
        """Report whether the active tensor engine supports ``instance_norm``."""
        norm_op = self.instance_norm
        return tensor_engine.is_supported(norm_op)
67
68
class LayerNormBench(NormalizationBench):
    """Layer-normalization benchmark over the shared NCHW input."""

    def forward(self):
        # Normalized shape is [H, W] (presumably the two trailing dims of data).
        normalized_shape = [self.H, self.W]
        return self.layer_norm(self.data, normalized_shape)

    @staticmethod
    def module():
        """Return the name under which this benchmark is registered."""
        return "layernorm"
77
78
# Register each benchmark class with the benchmark harness, in this order.
for _bench_cls in (BatchNormBench, InstanceNormBench, LayerNormBench):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls
82