xref: /aosp_15_r20/external/pytorch/tools/test/test_vulkan_codegen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import tempfile
2import unittest
3
4from tools.gen_vulkan_spv import DEFAULT_ENV, SPVGenerator
5
6
7####################
8# Data for testing #
9####################
10
# Templated GLSL compute shader fed to SPVGenerator as "test_shader.glsl".
# Lines starting with `$` are template directives: helper definitions
# (`$def is_int` / `$def is_uint`), conditionals on DTYPE and INPLACE, and a
# `$for` loop whose trip count is ITER[0]. `${...}` marks substitutions.
# DTYPE, INPLACE, OPERATOR and ITER are declared in test_params_yaml; FORMAT
# and PRECISION are not, so presumably they come from DEFAULT_ENV -- verify in
# tools/gen_vulkan_spv.py.
# NOTE(review): whitespace/indentation inside this string is presumably
# significant to the template engine; do not reformat it.
test_shader = """
#version 450 core

#define FORMAT ${FORMAT}
#define PRECISION ${PRECISION}
#define OP(X) ${OPERATOR}

$def is_int(dtype):
$   return dtype in {"int", "int32", "int8"}

$def is_uint(dtype):
$   return dtype in {"uint", "uint32", "uint8"}

$if is_int(DTYPE):
  #define VEC4_T ivec4
$elif is_uint(DTYPE):
  #define VEC4_T uvec4
$else:
  #define VEC4_T vec4

$if not INPLACE:
  $if is_int(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly iimage3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
  $elif is_uint(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly uimage3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION usampler3D uInput;
  $else:
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
$else:
  $if is_int(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict iimage3D uOutput;
  $elif is_uint(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict uimage3D uOutput;
  $else:
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  $if not INPLACE:
    VEC4_T v = texelFetch(uInput, pos, 0);
  $else:
    VEC4_T v = imageLoad(uOutput, pos);
  $for i in range(ITER[0]):
    for (int i = 0; i < ${ITER[1]}; ++i) {
        v = OP(v + i);
    }
  imageStore(uOutput, pos, OP(v));
}

"""
65
# Parameter file for the `test_shader` template, written as "test_params.yaml".
#   parameter_names_with_default_values: defaults for every template parameter.
#   generate_variant_forall: parameters whose listed VALUEs are combined into
#     separate variants. Judging by the names expected in testOutputMap, a
#     SUFFIX (if given) is used in the variant name; otherwise the VALUE is,
#     and an empty SUFFIX adds nothing.
#   shader_variants: named variants. Each may override defaults (OPERATOR,
#     ITER) and may replace generate_variant_forall, as test_shader_3 does.
# NOTE(review): `!!python/tuple` is a PyYAML Python-object tag that
# yaml.safe_load rejects; presumably the generator uses a loader that allows
# it -- verify in tools/gen_vulkan_spv.py.
test_params_yaml = """
test_shader:
  parameter_names_with_default_values:
    DTYPE: float
    INPLACE: false
    OPERATOR: X + 3
    ITER: !!python/tuple [3, 5]
  generate_variant_forall:
    INPLACE:
      - VALUE: false
        SUFFIX: ""
      - VALUE: true
        SUFFIX: inplace
    DTYPE:
      - VALUE: int8
      - VALUE: float
  shader_variants:
    - NAME: test_shader_1
    - NAME: test_shader_3
      OPERATOR: X - 1
      ITER: !!python/tuple [3, 2]
      generate_variant_forall:
        DTYPE:
        - VALUE: float
        - VALUE: int

"""
93
94##############
95# Unit Tests #
96##############
97
98
class TestVulkanSPVCodegen(unittest.TestCase):
    """Check which shader variants SPVGenerator derives from a template + YAML.

    Each test gets a fresh temporary source directory containing
    ``test_shader.glsl`` (the ``test_shader`` template) and
    ``test_params.yaml`` (the ``test_params_yaml`` parameters), and an
    SPVGenerator built over that directory.
    """

    def setUp(self) -> None:
        self.tmpdir = tempfile.TemporaryDirectory()
        # Register removal immediately: unittest runs addCleanup callbacks even
        # if a later step of setUp raises, whereas tearDown would be skipped.
        self.addCleanup(self.tmpdir.cleanup)

        # The file's base name presumably must match the top-level
        # `test_shader:` key in the YAML -- verify in gen_vulkan_spv.
        with open(
            f"{self.tmpdir.name}/test_shader.glsl", "w", encoding="utf-8"
        ) as f:
            f.write(test_shader)

        with open(
            f"{self.tmpdir.name}/test_params.yaml", "w", encoding="utf-8"
        ) as f:
            f.write(test_params_yaml)

        # Scratch output directory; not read by the current test.
        self.tmpoutdir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tmpoutdir.cleanup)

        # glslc_path=None: no shader compiler is supplied, so presumably only
        # template/variant expansion (not SPIR-V compilation) is exercised.
        self.generator = SPVGenerator(
            src_dir_paths=self.tmpdir.name, env=DEFAULT_ENV, glslc_path=None
        )

    def cleanUp(self) -> None:
        """Remove both temporary directories.

        unittest never calls a method named ``cleanUp`` (its hook is
        ``tearDown``); the real cleanup is registered via ``addCleanup`` in
        ``setUp``. Kept for backward compatibility only; calling it is safe
        because ``TemporaryDirectory.cleanup`` may be called more than once.
        """
        self.tmpdir.cleanup()
        self.tmpoutdir.cleanup()

    def testOutputMap(self) -> None:
        """Exactly the expected set of variant names is generated."""
        # test_shader_1 uses the top-level generate_variant_forall, so it gets
        # every INPLACE x DTYPE combination (2 x 2 = 4 variants). The empty
        # SUFFIX for INPLACE=false adds no name segment. test_shader_3
        # overrides generate_variant_forall with DTYPE only, so it gets 2
        # variants and no "inplace" ones.
        expected_output_shaders = {
            "test_shader_1_float",
            "test_shader_1_inplace_float",
            "test_shader_1_inplace_int8",
            "test_shader_1_int8",
            "test_shader_3_float",
            "test_shader_3_int",
        }

        actual_output_shaders = set(self.generator.output_shader_map.keys())

        self.assertEqual(expected_output_shaders, actual_output_shaders)
135