xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/definitions/setup.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Define some common setup blocks which benchmarks can reuse."""
2# mypy: ignore-errors
3import enum
4
5from core.api import GroupedSetup
6from core.utils import parse_stmts
7
8
# Minimal 2-D setup: binds `x` to a 4x4 tensor of ones, given first as a
# Python statement and then as the matching C++ statement.
_TRIVIAL_2D = GroupedSetup(
    r"x = torch.ones((4, 4))",
    r"auto x = torch::ones({4, 4});",
)
10
11
# Minimal 3-D setup: binds `x` to a 4x4x4 tensor of ones, given first as a
# Python statement and then as the matching C++ statement.
_TRIVIAL_3D = GroupedSetup(
    r"x = torch.ones((4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4});",
)
15
16
# Minimal 4-D setup: binds `x` to a 4x4x4x4 tensor of ones, given first as a
# Python statement and then as the matching C++ statement.
_TRIVIAL_4D = GroupedSetup(
    r"x = torch.ones((4, 4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4, 4});",
)
20
21
# Setup for training-style benchmarks, written as a side-by-side table with a
# Python column and a C++ column. Both columns define the same names:
#   * inputs `x` and `y`: ones tensors of shape (1,)
#   * weights `w0` and `w1` of shape (1,) and `w2` of shape (2,), each marked
#     as requiring grad (`requires_grad=True` / `set_requires_grad(true)`).
# The result of `parse_stmts` is unpacked into `GroupedSetup`'s positional
# arguments.
# NOTE(review): presumably `parse_stmts` splits the table on `|` into the
# per-language statement strings -- confirm against core.utils. The table's
# whitespace and column alignment are left untouched for that reason.
_TRAINING = GroupedSetup(
    *parse_stmts(
        r"""
        Python                                   | C++
        ---------------------------------------- | ----------------------------------------
        # Inputs                                 | // Inputs
        x = torch.ones((1,))                     | auto x = torch::ones({1});
        y = torch.ones((1,))                     | auto y = torch::ones({1});
                                                 |
        # Weights                                | // Weights
        w0 = torch.ones(                         | auto w0 = torch::ones({1});
            (1,), requires_grad=True)            | w0.set_requires_grad(true);
        w1 = torch.ones(                         | auto w1 = torch::ones({1});
            (1,), requires_grad=True)            | w1.set_requires_grad(true);
        w2 = torch.ones(                         | auto w2 = torch::ones({2});
            (2,), requires_grad=True)            | w2.set_requires_grad(true);
    """
    )
)
41
42
class Setup(enum.Enum):
    """Named, reusable setup blocks that benchmarks can refer to.

    Each member's value is one of the module-level ``GroupedSetup``
    constants defined in this file.
    """

    # Single-tensor setups of increasing rank.
    TRIVIAL_2D = _TRIVIAL_2D
    TRIVIAL_3D = _TRIVIAL_3D
    TRIVIAL_4D = _TRIVIAL_4D

    # Inputs plus grad-requiring weights.
    TRAINING = _TRAINING
48