"""Unit tests for the Vulkan GLSL shader-variant code generator (SPVGenerator)."""

import tempfile
import unittest

from tools.gen_vulkan_spv import DEFAULT_ENV, SPVGenerator


####################
# Data for testing #
####################

test_shader = """
#version 450 core

#define FORMAT ${FORMAT}
#define PRECISION ${PRECISION}
#define OP(X) ${OPERATOR}

$def is_int(dtype):
$  return dtype in {"int", "int32", "int8"}

$def is_uint(dtype):
$  return dtype in {"uint", "uint32", "uint8"}

$if is_int(DTYPE):
  #define VEC4_T ivec4
$elif is_uint(DTYPE):
  #define VEC4_T uvec4
$else:
  #define VEC4_T vec4

$if not INPLACE:
  $if is_int(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly iimage3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
  $elif is_uint(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly uimage3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION usampler3D uInput;
  $else:
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
    layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
$else:
  $if is_int(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict iimage3D uOutput;
  $elif is_uint(DTYPE):
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict uimage3D uOutput;
  $else:
    layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict image3D uOutput;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  $if not INPLACE:
    VEC4_T v = texelFetch(uInput, pos, 0);
  $else:
    VEC4_T v = imageLoad(uOutput, pos);
  $for i in range(ITER[0]):
    for (int i = 0; i < ${ITER[1]}; ++i) {
        v = OP(v + i);
    }
  imageStore(uOutput, pos, OP(v));
}

"""

test_params_yaml = """
test_shader:
  parameter_names_with_default_values:
    DTYPE: float
    INPLACE: false
    OPERATOR: X + 3
    ITER: !!python/tuple [3, 5]
  generate_variant_forall:
    INPLACE:
      - VALUE: false
        SUFFIX: ""
      - VALUE: true
        SUFFIX: inplace
    DTYPE:
      - VALUE: int8
      - VALUE: float
  shader_variants:
    - NAME: test_shader_1
    - NAME: test_shader_3
      OPERATOR: X - 1
      ITER: !!python/tuple [3, 2]
      generate_variant_forall:
        DTYPE:
          - VALUE: float
          - VALUE: int

"""

##############
# Unit Tests #
##############


class TestVulkanSPVCodegen(unittest.TestCase):
    """Checks which shader variants SPVGenerator derives from a template + params YAML."""

    def setUp(self) -> None:
        # Source directory holding the GLSL template and its parameter YAML.
        self.tmpdir = tempfile.TemporaryDirectory()
        # Register cleanup right away (instead of relying on tearDown) so the
        # directory is removed even if a later step of setUp raises, e.g. the
        # SPVGenerator constructor below; tearDown is skipped in that case.
        self.addCleanup(self.tmpdir.cleanup)

        # The template file name (minus extension) must match the top-level key
        # "test_shader" in test_params_yaml so the generator pairs them up.
        with open(f"{self.tmpdir.name}/test_shader.glsl", "w", encoding="utf-8") as f:
            f.write(test_shader)

        with open(f"{self.tmpdir.name}/test_params.yaml", "w", encoding="utf-8") as f:
            f.write(test_params_yaml)

        self.tmpoutdir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tmpoutdir.cleanup)

        self.generator = SPVGenerator(
            src_dir_paths=self.tmpdir.name, env=DEFAULT_ENV, glslc_path=None
        )

    def cleanUp(self) -> None:
        """Remove both temporary directories.

        unittest never invokes a method with this name; the actual cleanup is
        registered via ``addCleanup`` in ``setUp``. Kept only so existing
        explicit callers keep working.
        """
        self.tmpdir.cleanup()
        self.tmpoutdir.cleanup()

    def testOutputMap(self) -> None:
        """The output map holds exactly one entry per expanded shader variant."""
        # Each shader variant will produce variants generated based on all possible combinations
        # of the DTYPE and INPLACE parameters. test_shader_3 has fewer generated variants due to
        # a custom specified generate_variant_forall field.
        expected_output_shaders = {
            "test_shader_1_float",
            "test_shader_1_inplace_float",
            "test_shader_1_inplace_int8",
            "test_shader_1_int8",
            "test_shader_3_float",
            "test_shader_3_int",
        }

        actual_output_shaders = set(self.generator.output_shader_map.keys())

        self.assertEqual(expected_output_shaders, actual_output_shaders)