"""Shared setup blocks that benchmark definitions can reuse.

Each ``GroupedSetup`` below pairs a Python setup snippet with the equivalent
C++ setup snippet (Python first, C++ second). The ``Setup`` enum exposes them
under stable names.
"""
# mypy: ignore-errors
import enum

from core.api import GroupedSetup
from core.utils import parse_stmts


# ``x``: a 4x4 tensor of ones.
_TRIVIAL_2D = GroupedSetup(
    r"x = torch.ones((4, 4))",
    r"auto x = torch::ones({4, 4});",
)


# ``x``: a 4x4x4 tensor of ones.
_TRIVIAL_3D = GroupedSetup(
    r"x = torch.ones((4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4});",
)


# ``x``: a 4x4x4x4 tensor of ones.
_TRIVIAL_4D = GroupedSetup(
    r"x = torch.ones((4, 4, 4, 4))",
    r"auto x = torch::ones({4, 4, 4, 4});",
)


# Inputs ``x`` and ``y`` (shape (1,)) plus weights ``w0``/``w1`` (shape (1,))
# and ``w2`` (shape (2,)), all three weights requiring grad.
# The side-by-side table is handed to ``parse_stmts`` and its results are
# unpacked positionally into ``GroupedSetup`` (presumably the Python half and
# the C++ half of the table).
# NOTE(review): ``parse_stmts`` likely relies on the column layout (header,
# dashed separator, position of ``|``) -- keep the columns aligned when editing.
_TRAINING = GroupedSetup(
    *parse_stmts(
        r"""
        Python                                   | C++
        ---------------------------------------- | ----------------------------------------
        # Inputs                                 | // Inputs
        x = torch.ones((1,))                     | auto x = torch::ones({1});
        y = torch.ones((1,))                     | auto y = torch::ones({1});
                                                 |
        # Weights                                | // Weights
        w0 = torch.ones(                         | auto w0 = torch::ones({1});
            (1,), requires_grad=True)            | w0.set_requires_grad(true);
        w1 = torch.ones(                         | auto w1 = torch::ones({1});
            (1,), requires_grad=True)            | w1.set_requires_grad(true);
        w2 = torch.ones(                         | auto w2 = torch::ones({2});
            (2,), requires_grad=True)            | w2.set_requires_grad(true);
    """,
    )
)


class Setup(enum.Enum):
    """Named handles for the shared setup blocks defined in this module.

    Each member's ``value`` is the corresponding ``GroupedSetup`` object.
    """

    # Single tensor ``x`` of ones with 2, 3 or 4 dimensions of size 4.
    TRIVIAL_2D = _TRIVIAL_2D
    TRIVIAL_3D = _TRIVIAL_3D
    TRIVIAL_4D = _TRIVIAL_4D

    # Inputs plus grad-requiring weights.
    TRAINING = _TRAINING