1 // Copyright (c) 2021 The Vulkano developers
2 // Licensed under the Apache License, Version 2.0
3 // <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
6 // at your option. All files in the project carrying such
7 // notice may not be copied, modified, or distributed except
8 // according to those terms.
9 
//! Extraction of information from SPIR-V modules that is needed by the rest of Vulkano.
11 
12 use super::{DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages};
13 use crate::{
14     descriptor_set::layout::DescriptorType,
15     image::view::ImageViewType,
16     pipeline::layout::PushConstantRange,
17     shader::{
18         spirv::{
19             Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
20             StorageClass,
21         },
22         DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution,
23         GeometryShaderInput, ShaderExecution, ShaderInterface, ShaderInterfaceEntry,
24         ShaderInterfaceEntryType, ShaderScalarType, ShaderStage,
25         SpecializationConstantRequirements,
26     },
27     DeviceSize,
28 };
29 use ahash::{HashMap, HashSet};
30 use std::borrow::Cow;
31 
32 /// Returns an iterator of the capabilities used by `spirv`.
33 #[inline]
spirv_capabilities(spirv: &Spirv) -> impl Iterator<Item = &Capability>34 pub fn spirv_capabilities(spirv: &Spirv) -> impl Iterator<Item = &Capability> {
35     spirv
36         .iter_capability()
37         .filter_map(|instruction| match instruction {
38             Instruction::Capability { capability } => Some(capability),
39             _ => None,
40         })
41 }
42 
43 /// Returns an iterator of the extensions used by `spirv`.
44 #[inline]
spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str>45 pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str> {
46     spirv
47         .iter_extension()
48         .filter_map(|instruction| match instruction {
49             Instruction::Extension { name } => Some(name.as_str()),
50             _ => None,
51         })
52 }
53 
54 /// Returns an iterator over all entry points in `spirv`, with information about the entry point.
55 #[inline]
entry_points( spirv: &Spirv, ) -> impl Iterator<Item = (String, ExecutionModel, EntryPointInfo)> + '_56 pub fn entry_points(
57     spirv: &Spirv,
58 ) -> impl Iterator<Item = (String, ExecutionModel, EntryPointInfo)> + '_ {
59     let interface_variables = interface_variables(spirv);
60 
61     spirv.iter_entry_point().filter_map(move |instruction| {
62         let (execution_model, function_id, entry_point_name, interface) = match instruction {
63             Instruction::EntryPoint {
64                 execution_model,
65                 entry_point,
66                 name,
67                 interface,
68                 ..
69             } => (*execution_model, *entry_point, name, interface),
70             _ => return None,
71         };
72 
73         let execution = shader_execution(spirv, execution_model, function_id);
74         let stage = ShaderStage::from(execution);
75 
76         let descriptor_binding_requirements = inspect_entry_point(
77             &interface_variables.descriptor_binding,
78             spirv,
79             stage,
80             function_id,
81         );
82         let push_constant_requirements = push_constant_requirements(spirv, stage);
83         let specialization_constant_requirements = specialization_constant_requirements(spirv);
84         let input_interface = shader_interface(
85             spirv,
86             interface,
87             StorageClass::Input,
88             matches!(
89                 execution_model,
90                 ExecutionModel::TessellationControl
91                     | ExecutionModel::TessellationEvaluation
92                     | ExecutionModel::Geometry
93             ),
94         );
95         let output_interface = shader_interface(
96             spirv,
97             interface,
98             StorageClass::Output,
99             matches!(execution_model, ExecutionModel::TessellationControl),
100         );
101 
102         Some((
103             entry_point_name.clone(),
104             execution_model,
105             EntryPointInfo {
106                 execution,
107                 descriptor_binding_requirements,
108                 push_constant_requirements,
109                 specialization_constant_requirements,
110                 input_interface,
111                 output_interface,
112             },
113         ))
114     })
115 }
116 
117 /// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
shader_execution( spirv: &Spirv, execution_model: ExecutionModel, function_id: Id, ) -> ShaderExecution118 fn shader_execution(
119     spirv: &Spirv,
120     execution_model: ExecutionModel,
121     function_id: Id,
122 ) -> ShaderExecution {
123     match execution_model {
124         ExecutionModel::Vertex => ShaderExecution::Vertex,
125 
126         ExecutionModel::TessellationControl => ShaderExecution::TessellationControl,
127 
128         ExecutionModel::TessellationEvaluation => ShaderExecution::TessellationEvaluation,
129 
130         ExecutionModel::Geometry => {
131             let mut input = None;
132 
133             for instruction in spirv.iter_execution_mode() {
134                 let mode = match instruction {
135                     Instruction::ExecutionMode {
136                         entry_point, mode, ..
137                     } if *entry_point == function_id => mode,
138                     _ => continue,
139                 };
140 
141                 match mode {
142                     ExecutionMode::InputPoints => {
143                         input = Some(GeometryShaderInput::Points);
144                     }
145                     ExecutionMode::InputLines => {
146                         input = Some(GeometryShaderInput::Lines);
147                     }
148                     ExecutionMode::InputLinesAdjacency => {
149                         input = Some(GeometryShaderInput::LinesWithAdjacency);
150                     }
151                     ExecutionMode::Triangles => {
152                         input = Some(GeometryShaderInput::Triangles);
153                     }
154                     ExecutionMode::InputTrianglesAdjacency => {
155                         input = Some(GeometryShaderInput::TrianglesWithAdjacency);
156                     }
157                     _ => (),
158                 }
159             }
160 
161             ShaderExecution::Geometry(GeometryShaderExecution {
162                 input: input
163                     .expect("Geometry shader does not have an input primitive ExecutionMode"),
164             })
165         }
166 
167         ExecutionModel::Fragment => {
168             let mut fragment_tests_stages = FragmentTestsStages::Late;
169 
170             for instruction in spirv.iter_execution_mode() {
171                 let mode = match instruction {
172                     Instruction::ExecutionMode {
173                         entry_point, mode, ..
174                     } if *entry_point == function_id => mode,
175                     _ => continue,
176                 };
177 
178                 #[allow(clippy::single_match)]
179                 match mode {
180                     ExecutionMode::EarlyFragmentTests => {
181                         fragment_tests_stages = FragmentTestsStages::Early;
182                     }
183                     /*ExecutionMode::EarlyAndLateFragmentTestsAMD => {
184                         fragment_tests_stages = FragmentTestsStages::EarlyAndLate;
185                     }*/
186                     _ => (),
187                 }
188             }
189 
190             ShaderExecution::Fragment(FragmentShaderExecution {
191                 fragment_tests_stages,
192             })
193         }
194 
195         ExecutionModel::GLCompute => ShaderExecution::Compute,
196 
197         ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
198         ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
199         ExecutionModel::AnyHitKHR => ShaderExecution::AnyHit,
200         ExecutionModel::ClosestHitKHR => ShaderExecution::ClosestHit,
201         ExecutionModel::MissKHR => ShaderExecution::Miss,
202         ExecutionModel::CallableKHR => ShaderExecution::Callable,
203 
204         ExecutionModel::TaskNV => ShaderExecution::Task,
205         ExecutionModel::MeshNV => ShaderExecution::Mesh,
206 
207         ExecutionModel::Kernel => todo!(),
208     }
209 }
210 
/// Global variables of a SPIR-V module that entry points may access through descriptors.
#[derive(Clone, Debug, Default)]
struct InterfaceVariables {
    /// Global `OpVariable`s in the `StorageBuffer`, `Uniform` or `UniformConstant` storage
    /// classes, keyed by the variable's result id.
    descriptor_binding: HashMap<Id, DescriptorBindingVariable>,
}
215 
// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface.
/// A global variable that is bound to a descriptor, together with the requirements on that
/// descriptor binding.
#[derive(Clone, Debug)]
struct DescriptorBindingVariable {
    /// Descriptor set number of the variable. Presumably taken from its `DescriptorSet`
    /// decoration by `descriptor_binding_requirements_of` (not visible here).
    set: u32,
    /// Binding number of the variable within its set. Presumably taken from its `Binding`
    /// decoration.
    binding: u32,
    /// Requirements on the binding. When `inspect_entry_point` finds an access, it overwrites
    /// `stages` and fills the per-index `descriptors` entries.
    reqs: DescriptorBindingRequirements,
}
223 
interface_variables(spirv: &Spirv) -> InterfaceVariables224 fn interface_variables(spirv: &Spirv) -> InterfaceVariables {
225     let mut variables = InterfaceVariables::default();
226 
227     for instruction in spirv.iter_global() {
228         if let Instruction::Variable {
229             result_id,
230             result_type_id: _,
231             storage_class,
232             ..
233         } = instruction
234         {
235             match storage_class {
236                 StorageClass::StorageBuffer
237                 | StorageClass::Uniform
238                 | StorageClass::UniformConstant => {
239                     variables.descriptor_binding.insert(
240                         *result_id,
241                         descriptor_binding_requirements_of(spirv, *result_id),
242                     );
243                 }
244                 _ => (),
245             }
246         }
247     }
248 
249     variables
250 }
251 
inspect_entry_point( global: &HashMap<Id, DescriptorBindingVariable>, spirv: &Spirv, stage: ShaderStage, entry_point: Id, ) -> HashMap<(u32, u32), DescriptorBindingRequirements>252 fn inspect_entry_point(
253     global: &HashMap<Id, DescriptorBindingVariable>,
254     spirv: &Spirv,
255     stage: ShaderStage,
256     entry_point: Id,
257 ) -> HashMap<(u32, u32), DescriptorBindingRequirements> {
258     struct Context<'a> {
259         global: &'a HashMap<Id, DescriptorBindingVariable>,
260         spirv: &'a Spirv,
261         stage: ShaderStage,
262         inspected_functions: HashSet<Id>,
263         result: HashMap<Id, DescriptorBindingVariable>,
264     }
265 
266     impl<'a> Context<'a> {
267         fn instruction_chain<const N: usize>(
268             &mut self,
269             chain: [fn(&Spirv, Id) -> Option<Id>; N],
270             id: Id,
271         ) -> Option<(&mut DescriptorBindingVariable, Option<u32>)> {
272             let id = chain
273                 .into_iter()
274                 .try_fold(id, |id, func| func(self.spirv, id))?;
275 
276             if let Some(variable) = self.global.get(&id) {
277                 // Variable was accessed without an access chain, return with index 0.
278                 let variable = self.result.entry(id).or_insert_with(|| variable.clone());
279                 variable.reqs.stages = self.stage.into();
280                 return Some((variable, Some(0)));
281             }
282 
283             let (id, indexes) = match *self.spirv.id(id).instruction() {
284                 Instruction::AccessChain {
285                     base, ref indexes, ..
286                 } => (base, indexes),
287                 _ => return None,
288             };
289 
290             if let Some(variable) = self.global.get(&id) {
291                 // Variable was accessed with an access chain.
292                 // Retrieve index from instruction if it's a constant value.
293                 // TODO: handle a `None` index too?
294                 let index = match *self.spirv.id(*indexes.first().unwrap()).instruction() {
295                     Instruction::Constant { ref value, .. } => Some(value[0]),
296                     _ => None,
297                 };
298                 let variable = self.result.entry(id).or_insert_with(|| variable.clone());
299                 variable.reqs.stages = self.stage.into();
300                 return Some((variable, index));
301             }
302 
303             None
304         }
305 
306         fn inspect_entry_point_r(&mut self, function: Id) {
307             fn desc_reqs(
308                 descriptor_variable: Option<(&mut DescriptorBindingVariable, Option<u32>)>,
309             ) -> Option<&mut DescriptorRequirements> {
310                 descriptor_variable
311                     .map(|(variable, index)| variable.reqs.descriptors.entry(index).or_default())
312             }
313 
314             fn inst_image_texel_pointer(spirv: &Spirv, id: Id) -> Option<Id> {
315                 match *spirv.id(id).instruction() {
316                     Instruction::ImageTexelPointer { image, .. } => Some(image),
317                     _ => None,
318                 }
319             }
320 
321             fn inst_load(spirv: &Spirv, id: Id) -> Option<Id> {
322                 match *spirv.id(id).instruction() {
323                     Instruction::Load { pointer, .. } => Some(pointer),
324                     _ => None,
325                 }
326             }
327 
328             fn inst_sampled_image(spirv: &Spirv, id: Id) -> Option<Id> {
329                 match *spirv.id(id).instruction() {
330                     Instruction::SampledImage { sampler, .. } => Some(sampler),
331                     _ => Some(id),
332                 }
333             }
334 
335             self.inspected_functions.insert(function);
336             let mut in_function = false;
337 
338             for instruction in self.spirv.instructions() {
339                 if !in_function {
340                     match *instruction {
341                         Instruction::Function { result_id, .. } if result_id == function => {
342                             in_function = true;
343                         }
344                         _ => {}
345                     }
346                 } else {
347                     let stage = self.stage;
348 
349                     match *instruction {
350                         Instruction::AtomicLoad { pointer, .. } => {
351                             // Storage buffer
352                             if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
353                             {
354                                 desc_reqs.memory_read = stage.into();
355                             }
356 
357                             // Storage image
358                             if let Some(desc_reqs) = desc_reqs(
359                                 self.instruction_chain([inst_image_texel_pointer], pointer),
360                             ) {
361                                 desc_reqs.memory_read = stage.into();
362                                 desc_reqs.storage_image_atomic = true;
363                             }
364                         }
365 
366                         Instruction::AtomicStore { pointer, .. } => {
367                             // Storage buffer
368                             if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
369                             {
370                                 desc_reqs.memory_write = stage.into();
371                             }
372 
373                             // Storage image
374                             if let Some(desc_reqs) = desc_reqs(
375                                 self.instruction_chain([inst_image_texel_pointer], pointer),
376                             ) {
377                                 desc_reqs.memory_write = stage.into();
378                                 desc_reqs.storage_image_atomic = true;
379                             }
380                         }
381 
382                         Instruction::AtomicExchange { pointer, .. }
383                         | Instruction::AtomicCompareExchange { pointer, .. }
384                         | Instruction::AtomicCompareExchangeWeak { pointer, .. }
385                         | Instruction::AtomicIIncrement { pointer, .. }
386                         | Instruction::AtomicIDecrement { pointer, .. }
387                         | Instruction::AtomicIAdd { pointer, .. }
388                         | Instruction::AtomicISub { pointer, .. }
389                         | Instruction::AtomicSMin { pointer, .. }
390                         | Instruction::AtomicUMin { pointer, .. }
391                         | Instruction::AtomicSMax { pointer, .. }
392                         | Instruction::AtomicUMax { pointer, .. }
393                         | Instruction::AtomicAnd { pointer, .. }
394                         | Instruction::AtomicOr { pointer, .. }
395                         | Instruction::AtomicXor { pointer, .. }
396                         | Instruction::AtomicFlagTestAndSet { pointer, .. }
397                         | Instruction::AtomicFlagClear { pointer, .. }
398                         | Instruction::AtomicFMinEXT { pointer, .. }
399                         | Instruction::AtomicFMaxEXT { pointer, .. }
400                         | Instruction::AtomicFAddEXT { pointer, .. } => {
401                             // Storage buffer
402                             if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
403                             {
404                                 desc_reqs.memory_read = stage.into();
405                                 desc_reqs.memory_write = stage.into();
406                             }
407 
408                             // Storage image
409                             if let Some(desc_reqs) = desc_reqs(
410                                 self.instruction_chain([inst_image_texel_pointer], pointer),
411                             ) {
412                                 desc_reqs.memory_read = stage.into();
413                                 desc_reqs.memory_write = stage.into();
414                                 desc_reqs.storage_image_atomic = true;
415                             }
416                         }
417 
418                         Instruction::CopyMemory { target, source, .. } => {
419                             self.instruction_chain([], target);
420                             self.instruction_chain([], source);
421                         }
422 
423                         Instruction::CopyObject { operand, .. } => {
424                             self.instruction_chain([], operand);
425                         }
426 
427                         Instruction::ExtInst { ref operands, .. } => {
428                             // We don't know which extended instructions take pointers,
429                             // so we must interpret every operand as a pointer.
430                             for &operand in operands {
431                                 self.instruction_chain([], operand);
432                             }
433                         }
434 
435                         Instruction::FunctionCall {
436                             function,
437                             ref arguments,
438                             ..
439                         } => {
440                             // Rather than trying to figure out the type of each argument, we just
441                             // try all of them as pointers.
442                             for &argument in arguments {
443                                 self.instruction_chain([], argument);
444                             }
445 
446                             if !self.inspected_functions.contains(&function) {
447                                 self.inspect_entry_point_r(function);
448                             }
449                         }
450 
451                         Instruction::FunctionEnd => return,
452 
453                         Instruction::ImageGather {
454                             sampled_image,
455                             image_operands,
456                             ..
457                         }
458                         | Instruction::ImageSparseGather {
459                             sampled_image,
460                             image_operands,
461                             ..
462                         } => {
463                             if let Some(desc_reqs) =
464                                 desc_reqs(self.instruction_chain(
465                                     [inst_sampled_image, inst_load],
466                                     sampled_image,
467                                 ))
468                             {
469                                 desc_reqs.memory_read = stage.into();
470                                 desc_reqs.sampler_no_ycbcr_conversion = true;
471 
472                                 if image_operands.as_ref().map_or(false, |image_operands| {
473                                     image_operands.bias.is_some()
474                                         || image_operands.const_offset.is_some()
475                                         || image_operands.offset.is_some()
476                                 }) {
477                                     desc_reqs.sampler_no_unnormalized_coordinates = true;
478                                 }
479                             }
480                         }
481 
482                         Instruction::ImageDrefGather { sampled_image, .. }
483                         | Instruction::ImageSparseDrefGather { sampled_image, .. } => {
484                             if let Some(desc_reqs) =
485                                 desc_reqs(self.instruction_chain(
486                                     [inst_sampled_image, inst_load],
487                                     sampled_image,
488                                 ))
489                             {
490                                 desc_reqs.memory_read = stage.into();
491                                 desc_reqs.sampler_no_unnormalized_coordinates = true;
492                                 desc_reqs.sampler_no_ycbcr_conversion = true;
493                             }
494                         }
495 
496                         Instruction::ImageSampleImplicitLod {
497                             sampled_image,
498                             image_operands,
499                             ..
500                         }
501                         | Instruction::ImageSampleProjImplicitLod {
502                             sampled_image,
503                             image_operands,
504                             ..
505                         }
506                         | Instruction::ImageSparseSampleProjImplicitLod {
507                             sampled_image,
508                             image_operands,
509                             ..
510                         }
511                         | Instruction::ImageSparseSampleImplicitLod {
512                             sampled_image,
513                             image_operands,
514                             ..
515                         } => {
516                             if let Some(desc_reqs) =
517                                 desc_reqs(self.instruction_chain(
518                                     [inst_sampled_image, inst_load],
519                                     sampled_image,
520                                 ))
521                             {
522                                 desc_reqs.memory_read = stage.into();
523                                 desc_reqs.sampler_no_unnormalized_coordinates = true;
524 
525                                 if image_operands.as_ref().map_or(false, |image_operands| {
526                                     image_operands.const_offset.is_some()
527                                         || image_operands.offset.is_some()
528                                 }) {
529                                     desc_reqs.sampler_no_ycbcr_conversion = true;
530                                 }
531                             }
532                         }
533 
534                         Instruction::ImageSampleProjExplicitLod {
535                             sampled_image,
536                             image_operands,
537                             ..
538                         }
539                         | Instruction::ImageSparseSampleProjExplicitLod {
540                             sampled_image,
541                             image_operands,
542                             ..
543                         } => {
544                             if let Some(desc_reqs) =
545                                 desc_reqs(self.instruction_chain(
546                                     [inst_sampled_image, inst_load],
547                                     sampled_image,
548                                 ))
549                             {
550                                 desc_reqs.memory_read = stage.into();
551                                 desc_reqs.sampler_no_unnormalized_coordinates = true;
552 
553                                 if image_operands.const_offset.is_some()
554                                     || image_operands.offset.is_some()
555                                 {
556                                     desc_reqs.sampler_no_ycbcr_conversion = true;
557                                 }
558                             }
559                         }
560 
561                         Instruction::ImageSampleDrefImplicitLod {
562                             sampled_image,
563                             image_operands,
564                             ..
565                         }
566                         | Instruction::ImageSampleProjDrefImplicitLod {
567                             sampled_image,
568                             image_operands,
569                             ..
570                         }
571                         | Instruction::ImageSparseSampleDrefImplicitLod {
572                             sampled_image,
573                             image_operands,
574                             ..
575                         }
576                         | Instruction::ImageSparseSampleProjDrefImplicitLod {
577                             sampled_image,
578                             image_operands,
579                             ..
580                         } => {
581                             if let Some(desc_reqs) =
582                                 desc_reqs(self.instruction_chain(
583                                     [inst_sampled_image, inst_load],
584                                     sampled_image,
585                                 ))
586                             {
587                                 desc_reqs.memory_read = stage.into();
588                                 desc_reqs.sampler_no_unnormalized_coordinates = true;
589                                 desc_reqs.sampler_compare = true;
590 
591                                 if image_operands.as_ref().map_or(false, |image_operands| {
592                                     image_operands.const_offset.is_some()
593                                         || image_operands.offset.is_some()
594                                 }) {
595                                     desc_reqs.sampler_no_ycbcr_conversion = true;
596                                 }
597                             }
598                         }
599 
600                         Instruction::ImageSampleDrefExplicitLod {
601                             sampled_image,
602                             image_operands,
603                             ..
604                         }
605                         | Instruction::ImageSampleProjDrefExplicitLod {
606                             sampled_image,
607                             image_operands,
608                             ..
609                         }
610                         | Instruction::ImageSparseSampleDrefExplicitLod {
611                             sampled_image,
612                             image_operands,
613                             ..
614                         }
615                         | Instruction::ImageSparseSampleProjDrefExplicitLod {
616                             sampled_image,
617                             image_operands,
618                             ..
619                         } => {
620                             if let Some(desc_reqs) =
621                                 desc_reqs(self.instruction_chain(
622                                     [inst_sampled_image, inst_load],
623                                     sampled_image,
624                                 ))
625                             {
626                                 desc_reqs.memory_read = stage.into();
627                                 desc_reqs.sampler_no_unnormalized_coordinates = true;
628                                 desc_reqs.sampler_compare = true;
629 
630                                 if image_operands.const_offset.is_some()
631                                     || image_operands.offset.is_some()
632                                 {
633                                     desc_reqs.sampler_no_ycbcr_conversion = true;
634                                 }
635                             }
636                         }
637 
638                         Instruction::ImageSampleExplicitLod {
639                             sampled_image,
640                             image_operands,
641                             ..
642                         }
643                         | Instruction::ImageSparseSampleExplicitLod {
644                             sampled_image,
645                             image_operands,
646                             ..
647                         } => {
648                             if let Some(desc_reqs) =
649                                 desc_reqs(self.instruction_chain(
650                                     [inst_sampled_image, inst_load],
651                                     sampled_image,
652                                 ))
653                             {
654                                 desc_reqs.memory_read = stage.into();
655 
656                                 if image_operands.bias.is_some()
657                                     || image_operands.const_offset.is_some()
658                                     || image_operands.offset.is_some()
659                                 {
660                                     desc_reqs.sampler_no_unnormalized_coordinates = true;
661                                 }
662 
663                                 if image_operands.const_offset.is_some()
664                                     || image_operands.offset.is_some()
665                                 {
666                                     desc_reqs.sampler_no_ycbcr_conversion = true;
667                                 }
668                             }
669                         }
670 
671                         Instruction::ImageTexelPointer { image, .. } => {
672                             self.instruction_chain([], image);
673                         }
674 
675                         Instruction::ImageRead { image, .. } => {
676                             if let Some(desc_reqs) =
677                                 desc_reqs(self.instruction_chain([inst_load], image))
678                             {
679                                 desc_reqs.memory_read = stage.into();
680                             }
681                         }
682 
683                         Instruction::ImageWrite { image, .. } => {
684                             if let Some(desc_reqs) =
685                                 desc_reqs(self.instruction_chain([inst_load], image))
686                             {
687                                 desc_reqs.memory_write = stage.into();
688                             }
689                         }
690 
691                         Instruction::Load { pointer, .. } => {
692                             if let Some((binding_variable, index)) =
693                                 self.instruction_chain([], pointer)
694                             {
695                                 // Only loads on buffers access memory directly.
696                                 // Loads on images load the image object itself, but don't touch
697                                 // the texels in memory yet.
698                                 if binding_variable.reqs.descriptor_types.iter().any(|ty| {
699                                     matches!(
700                                         ty,
701                                         DescriptorType::UniformBuffer
702                                             | DescriptorType::UniformBufferDynamic
703                                             | DescriptorType::StorageBuffer
704                                             | DescriptorType::StorageBufferDynamic
705                                     )
706                                 }) {
707                                     if let Some(desc_reqs) =
708                                         desc_reqs(Some((binding_variable, index)))
709                                     {
710                                         desc_reqs.memory_read = stage.into();
711                                     }
712                                 }
713                             }
714                         }
715 
716                         Instruction::SampledImage { image, sampler, .. } => {
717                             let identifier = match self.instruction_chain([inst_load], image) {
718                                 Some((variable, Some(index))) => DescriptorIdentifier {
719                                     set: variable.set,
720                                     binding: variable.binding,
721                                     index,
722                                 },
723                                 _ => continue,
724                             };
725 
726                             if let Some(desc_reqs) =
727                                 desc_reqs(self.instruction_chain([inst_load], sampler))
728                             {
729                                 desc_reqs.sampler_with_images.insert(identifier);
730                             }
731                         }
732 
733                         Instruction::Store { pointer, .. } => {
734                             // This can only apply to buffers, right?
735                             if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
736                             {
737                                 desc_reqs.memory_write = stage.into();
738                             }
739                         }
740 
741                         _ => (),
742                     }
743                 }
744             }
745         }
746     }
747 
748     let mut context = Context {
749         global,
750         spirv,
751         stage,
752         inspected_functions: HashSet::default(),
753         result: HashMap::default(),
754     };
755     context.inspect_entry_point_r(entry_point);
756 
757     context
758         .result
759         .into_values()
760         .map(|variable| ((variable.set, variable.binding), variable.reqs))
761         .collect()
762 }
763 
764 /// Returns a `DescriptorBindingRequirements` value for the pointed type.
765 ///
766 /// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface
descriptor_binding_requirements_of(spirv: &Spirv, variable_id: Id) -> DescriptorBindingVariable767 fn descriptor_binding_requirements_of(spirv: &Spirv, variable_id: Id) -> DescriptorBindingVariable {
768     let variable_id_info = spirv.id(variable_id);
769 
770     let mut reqs = DescriptorBindingRequirements {
771         descriptor_count: Some(1),
772         ..Default::default()
773     };
774 
775     let (mut next_type_id, is_storage_buffer) = {
776         let variable_type_id = match *variable_id_info.instruction() {
777             Instruction::Variable { result_type_id, .. } => result_type_id,
778             _ => panic!("Id {} is not a variable", variable_id),
779         };
780 
781         match *spirv.id(variable_type_id).instruction() {
782             Instruction::TypePointer {
783                 ty, storage_class, ..
784             } => (Some(ty), storage_class == StorageClass::StorageBuffer),
785             _ => panic!(
786                 "Variable {} result_type_id does not refer to a TypePointer instruction",
787                 variable_id
788             ),
789         }
790     };
791 
792     while let Some(id) = next_type_id {
793         let id_info = spirv.id(id);
794 
795         next_type_id = match *id_info.instruction() {
796             Instruction::TypeStruct { .. } => {
797                 let decoration_block = id_info.iter_decoration().any(|instruction| {
798                     matches!(
799                         instruction,
800                         Instruction::Decorate {
801                             decoration: Decoration::Block,
802                             ..
803                         }
804                     )
805                 });
806 
807                 let decoration_buffer_block = id_info.iter_decoration().any(|instruction| {
808                     matches!(
809                         instruction,
810                         Instruction::Decorate {
811                             decoration: Decoration::BufferBlock,
812                             ..
813                         }
814                     )
815                 });
816 
817                 assert!(
818                     decoration_block ^ decoration_buffer_block,
819                     "Structs in shader interface are expected to be decorated with one of Block or \
820                     BufferBlock",
821                 );
822 
823                 if decoration_buffer_block || decoration_block && is_storage_buffer {
824                     reqs.descriptor_types = vec![
825                         DescriptorType::StorageBuffer,
826                         DescriptorType::StorageBufferDynamic,
827                     ];
828                 } else {
829                     reqs.descriptor_types = vec![
830                         DescriptorType::UniformBuffer,
831                         DescriptorType::UniformBufferDynamic,
832                     ];
833                 };
834 
835                 None
836             }
837 
838             Instruction::TypeImage {
839                 sampled_type,
840                 dim,
841                 arrayed,
842                 ms,
843                 sampled,
844                 image_format,
845                 ..
846             } => {
847                 assert!(
848                     sampled != 0,
849                     "Vulkan requires that variables of type OpTypeImage have a Sampled operand of \
850                     1 or 2",
851                 );
852                 reqs.image_format = image_format.into();
853                 reqs.image_multisampled = ms != 0;
854                 reqs.image_scalar_type = Some(match *spirv.id(sampled_type).instruction() {
855                     Instruction::TypeInt {
856                         width, signedness, ..
857                     } => {
858                         assert!(width == 32); // TODO: 64-bit components
859                         match signedness {
860                             0 => ShaderScalarType::Uint,
861                             1 => ShaderScalarType::Sint,
862                             _ => unreachable!(),
863                         }
864                     }
865                     Instruction::TypeFloat { width, .. } => {
866                         assert!(width == 32); // TODO: 64-bit components
867                         ShaderScalarType::Float
868                     }
869                     _ => unreachable!(),
870                 });
871 
872                 match dim {
873                     Dim::SubpassData => {
874                         assert!(
875                             reqs.image_format.is_none(),
876                             "If Dim is SubpassData, Image Format must be Unknown",
877                         );
878                         assert!(sampled == 2, "If Dim is SubpassData, Sampled must be 2");
879                         assert!(arrayed == 0, "If Dim is SubpassData, Arrayed must be 0");
880 
881                         reqs.descriptor_types = vec![DescriptorType::InputAttachment];
882                     }
883                     Dim::Buffer => {
884                         if sampled == 1 {
885                             reqs.descriptor_types = vec![DescriptorType::UniformTexelBuffer];
886                         } else {
887                             reqs.descriptor_types = vec![DescriptorType::StorageTexelBuffer];
888                         }
889                     }
890                     _ => {
891                         reqs.image_view_type = Some(match (dim, arrayed) {
892                             (Dim::Dim1D, 0) => ImageViewType::Dim1d,
893                             (Dim::Dim1D, 1) => ImageViewType::Dim1dArray,
894                             (Dim::Dim2D, 0) => ImageViewType::Dim2d,
895                             (Dim::Dim2D, 1) => ImageViewType::Dim2dArray,
896                             (Dim::Dim3D, 0) => ImageViewType::Dim3d,
897                             (Dim::Dim3D, 1) => {
898                                 panic!("Vulkan doesn't support arrayed 3D textures")
899                             }
900                             (Dim::Cube, 0) => ImageViewType::Cube,
901                             (Dim::Cube, 1) => ImageViewType::CubeArray,
902                             (Dim::Rect, _) => {
903                                 panic!("Vulkan doesn't support rectangle textures")
904                             }
905                             _ => unreachable!(),
906                         });
907 
908                         if reqs.descriptor_types.is_empty() {
909                             if sampled == 1 {
910                                 reqs.descriptor_types = vec![DescriptorType::SampledImage];
911                             } else {
912                                 reqs.descriptor_types = vec![DescriptorType::StorageImage];
913                             }
914                         }
915                     }
916                 }
917 
918                 None
919             }
920 
921             Instruction::TypeSampler { .. } => {
922                 reqs.descriptor_types = vec![DescriptorType::Sampler];
923 
924                 None
925             }
926 
927             Instruction::TypeSampledImage { image_type, .. } => {
928                 reqs.descriptor_types = vec![DescriptorType::CombinedImageSampler];
929 
930                 Some(image_type)
931             }
932 
933             Instruction::TypeArray {
934                 element_type,
935                 length,
936                 ..
937             } => {
938                 let len = match spirv.id(length).instruction() {
939                     Instruction::Constant { value, .. } => {
940                         value.iter().rev().fold(0, |a, &b| (a << 32) | b as u64)
941                     }
942                     _ => panic!("failed to find array length"),
943                 };
944 
945                 if let Some(count) = reqs.descriptor_count.as_mut() {
946                     *count *= len as u32
947                 }
948 
949                 Some(element_type)
950             }
951 
952             Instruction::TypeRuntimeArray { element_type, .. } => {
953                 reqs.descriptor_count = None;
954 
955                 Some(element_type)
956             }
957 
958             Instruction::TypeAccelerationStructureKHR { .. } => None, // FIXME temporary workaround
959 
960             _ => {
961                 let name = variable_id_info
962                     .iter_name()
963                     .find_map(|instruction| match *instruction {
964                         Instruction::Name { ref name, .. } => Some(name.as_str()),
965                         _ => None,
966                     })
967                     .unwrap_or("__unnamed");
968 
969                 panic!(
970                     "Couldn't find relevant type for global variable `{}` (id {}, maybe \
971                     unimplemented)",
972                     name, variable_id,
973                 );
974             }
975         };
976     }
977 
978     DescriptorBindingVariable {
979         set: variable_id_info
980             .iter_decoration()
981             .find_map(|instruction| match *instruction {
982                 Instruction::Decorate {
983                     decoration: Decoration::DescriptorSet { descriptor_set },
984                     ..
985                 } => Some(descriptor_set),
986                 _ => None,
987             })
988             .unwrap(),
989         binding: variable_id_info
990             .iter_decoration()
991             .find_map(|instruction| match *instruction {
992                 Instruction::Decorate {
993                     decoration: Decoration::Binding { binding_point },
994                     ..
995                 } => Some(binding_point),
996                 _ => None,
997             })
998             .unwrap(),
999         reqs,
1000     }
1001 }
1002 
1003 /// Extracts the `PushConstantRange` from `spirv`.
push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option<PushConstantRange>1004 fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option<PushConstantRange> {
1005     spirv
1006         .iter_global()
1007         .find_map(|instruction| match *instruction {
1008             Instruction::TypePointer {
1009                 ty,
1010                 storage_class: StorageClass::PushConstant,
1011                 ..
1012             } => {
1013                 let id_info = spirv.id(ty);
1014                 assert!(matches!(
1015                     id_info.instruction(),
1016                     Instruction::TypeStruct { .. }
1017                 ));
1018                 let start = offset_of_struct(spirv, ty);
1019                 let end =
1020                     size_of_type(spirv, ty).expect("Found runtime-sized push constants") as u32;
1021 
1022                 Some(PushConstantRange {
1023                     stages: stage.into(),
1024                     offset: start,
1025                     size: end - start,
1026                 })
1027             }
1028             _ => None,
1029         })
1030 }
1031 
1032 /// Extracts the `SpecializationConstantRequirements` from `spirv`.
specialization_constant_requirements( spirv: &Spirv, ) -> HashMap<u32, SpecializationConstantRequirements>1033 fn specialization_constant_requirements(
1034     spirv: &Spirv,
1035 ) -> HashMap<u32, SpecializationConstantRequirements> {
1036     spirv
1037         .iter_global()
1038         .filter_map(|instruction| {
1039             match *instruction {
1040                 Instruction::SpecConstantTrue {
1041                     result_type_id,
1042                     result_id,
1043                 }
1044                 | Instruction::SpecConstantFalse {
1045                     result_type_id,
1046                     result_id,
1047                 }
1048                 | Instruction::SpecConstant {
1049                     result_type_id,
1050                     result_id,
1051                     ..
1052                 }
1053                 | Instruction::SpecConstantComposite {
1054                     result_type_id,
1055                     result_id,
1056                     ..
1057                 } => spirv
1058                     .id(result_id)
1059                     .iter_decoration()
1060                     .find_map(|instruction| match *instruction {
1061                         Instruction::Decorate {
1062                             decoration:
1063                                 Decoration::SpecId {
1064                                     specialization_constant_id,
1065                                 },
1066                             ..
1067                         } => Some(specialization_constant_id),
1068                         _ => None,
1069                     })
1070                     .map(|constant_id| {
1071                         let size = match *spirv.id(result_type_id).instruction() {
1072                             Instruction::TypeBool { .. } => {
1073                                 // Translate bool to Bool32
1074                                 std::mem::size_of::<ash::vk::Bool32>() as DeviceSize
1075                             }
1076                             _ => size_of_type(spirv, result_type_id)
1077                                 .expect("Found runtime-sized specialization constant"),
1078                         };
1079                         (constant_id, SpecializationConstantRequirements { size })
1080                     }),
1081                 _ => None,
1082             }
1083         })
1084         .collect()
1085 }
1086 
1087 /// Extracts the `ShaderInterface` with the given storage class from `spirv`.
shader_interface( spirv: &Spirv, interface: &[Id], filter_storage_class: StorageClass, ignore_first_array: bool, ) -> ShaderInterface1088 fn shader_interface(
1089     spirv: &Spirv,
1090     interface: &[Id],
1091     filter_storage_class: StorageClass,
1092     ignore_first_array: bool,
1093 ) -> ShaderInterface {
1094     let elements: Vec<_> = interface
1095         .iter()
1096         .filter_map(|&id| {
1097             let (result_type_id, result_id) = match *spirv.id(id).instruction() {
1098                 Instruction::Variable {
1099                     result_type_id,
1100                     result_id,
1101                     storage_class,
1102                     ..
1103                 } if storage_class == filter_storage_class => (result_type_id, result_id),
1104                 _ => return None,
1105             };
1106 
1107             if is_builtin(spirv, result_id) {
1108                 return None;
1109             }
1110 
1111             let id_info = spirv.id(result_id);
1112 
1113             let name = id_info
1114                 .iter_name()
1115                 .find_map(|instruction| match *instruction {
1116                     Instruction::Name { ref name, .. } => Some(Cow::Owned(name.clone())),
1117                     _ => None,
1118                 });
1119 
1120             let location = id_info
1121                 .iter_decoration()
1122                 .find_map(|instruction| match *instruction {
1123                     Instruction::Decorate {
1124                         decoration: Decoration::Location { location },
1125                         ..
1126                     } => Some(location),
1127                     _ => None,
1128                 })
1129                 .unwrap_or_else(|| {
1130                     panic!(
1131                         "Input/output variable with id {} (name {:?}) is missing a location",
1132                         result_id, name,
1133                     )
1134                 });
1135             let component = id_info
1136                 .iter_decoration()
1137                 .find_map(|instruction| match *instruction {
1138                     Instruction::Decorate {
1139                         decoration: Decoration::Component { component },
1140                         ..
1141                     } => Some(component),
1142                     _ => None,
1143                 })
1144                 .unwrap_or(0);
1145 
1146             let ty = shader_interface_type_of(spirv, result_type_id, ignore_first_array);
1147             assert!(ty.num_elements >= 1);
1148 
1149             Some(ShaderInterfaceEntry {
1150                 location,
1151                 component,
1152                 ty,
1153                 name,
1154             })
1155         })
1156         .collect();
1157 
1158     // Checking for overlapping elements.
1159     for (offset, element1) in elements.iter().enumerate() {
1160         for element2 in elements.iter().skip(offset + 1) {
1161             if element1.location == element2.location
1162                 || (element1.location < element2.location
1163                     && element1.location + element1.ty.num_locations() > element2.location)
1164                 || (element2.location < element1.location
1165                     && element2.location + element2.ty.num_locations() > element1.location)
1166             {
1167                 panic!(
1168                     "The locations of attributes `{:?}` ({}..{}) and `{:?}` ({}..{}) overlap",
1169                     element1.name,
1170                     element1.location,
1171                     element1.location + element1.ty.num_locations(),
1172                     element2.name,
1173                     element2.location,
1174                     element2.location + element2.ty.num_locations(),
1175                 );
1176             }
1177         }
1178     }
1179 
1180     ShaderInterface { elements }
1181 }
1182 
1183 /// Returns the size of a type, or `None` if its size cannot be determined.
size_of_type(spirv: &Spirv, id: Id) -> Option<DeviceSize>1184 fn size_of_type(spirv: &Spirv, id: Id) -> Option<DeviceSize> {
1185     let id_info = spirv.id(id);
1186 
1187     match *id_info.instruction() {
1188         Instruction::TypeBool { .. } => {
1189             panic!("Can't put booleans in structs")
1190         }
1191         Instruction::TypeInt { width, .. } | Instruction::TypeFloat { width, .. } => {
1192             assert!(width % 8 == 0);
1193             Some(width as DeviceSize / 8)
1194         }
1195         Instruction::TypeVector {
1196             component_type,
1197             component_count,
1198             ..
1199         } => size_of_type(spirv, component_type)
1200             .map(|component_size| component_size * component_count as DeviceSize),
1201         Instruction::TypeMatrix {
1202             column_type,
1203             column_count,
1204             ..
1205         } => {
1206             // FIXME: row-major or column-major
1207             size_of_type(spirv, column_type)
1208                 .map(|column_size| column_size * column_count as DeviceSize)
1209         }
1210         Instruction::TypeArray { length, .. } => {
1211             let stride = id_info
1212                 .iter_decoration()
1213                 .find_map(|instruction| match *instruction {
1214                     Instruction::Decorate {
1215                         decoration: Decoration::ArrayStride { array_stride },
1216                         ..
1217                     } => Some(array_stride),
1218                     _ => None,
1219                 })
1220                 .unwrap();
1221             let length = match spirv.id(length).instruction() {
1222                 Instruction::Constant { value, .. } => Some(
1223                     value
1224                         .iter()
1225                         .rev()
1226                         .fold(0u64, |a, &b| (a << 32) | b as DeviceSize),
1227                 ),
1228                 _ => None,
1229             }
1230             .unwrap();
1231 
1232             Some(stride as DeviceSize * length)
1233         }
1234         Instruction::TypeRuntimeArray { .. } => None,
1235         Instruction::TypeStruct {
1236             ref member_types, ..
1237         } => {
1238             let mut end_of_struct = 0;
1239 
1240             for (&member, member_info) in member_types.iter().zip(id_info.iter_members()) {
1241                 // Built-ins have an unknown size.
1242                 if member_info.iter_decoration().any(|instruction| {
1243                     matches!(
1244                         instruction,
1245                         Instruction::MemberDecorate {
1246                             decoration: Decoration::BuiltIn { .. },
1247                             ..
1248                         }
1249                     )
1250                 }) {
1251                     return None;
1252                 }
1253 
1254                 // Some structs don't have `Offset` decorations, in the case they are used as local
1255                 // variables only. Ignoring these.
1256                 let offset =
1257                     member_info
1258                         .iter_decoration()
1259                         .find_map(|instruction| match *instruction {
1260                             Instruction::MemberDecorate {
1261                                 decoration: Decoration::Offset { byte_offset },
1262                                 ..
1263                             } => Some(byte_offset),
1264                             _ => None,
1265                         })?;
1266                 let size = size_of_type(spirv, member)?;
1267                 end_of_struct = end_of_struct.max(offset as DeviceSize + size);
1268             }
1269 
1270             Some(end_of_struct)
1271         }
1272         _ => panic!("Type {} not found", id),
1273     }
1274 }
1275 
1276 /// Returns the smallest offset of all members of a struct, or 0 if `id` is not a struct.
offset_of_struct(spirv: &Spirv, id: Id) -> u321277 fn offset_of_struct(spirv: &Spirv, id: Id) -> u32 {
1278     spirv
1279         .id(id)
1280         .iter_members()
1281         .filter_map(|member_info| {
1282             member_info
1283                 .iter_decoration()
1284                 .find_map(|instruction| match *instruction {
1285                     Instruction::MemberDecorate {
1286                         decoration: Decoration::Offset { byte_offset },
1287                         ..
1288                     } => Some(byte_offset),
1289                     _ => None,
1290                 })
1291         })
1292         .min()
1293         .unwrap_or(0)
1294 }
1295 
1296 /// If `ignore_first_array` is true, the function expects the outermost instruction to be
1297 /// `OpTypeArray`. If it's the case, the OpTypeArray will be ignored. If not, the function will
1298 /// panic.
shader_interface_type_of( spirv: &Spirv, id: Id, ignore_first_array: bool, ) -> ShaderInterfaceEntryType1299 fn shader_interface_type_of(
1300     spirv: &Spirv,
1301     id: Id,
1302     ignore_first_array: bool,
1303 ) -> ShaderInterfaceEntryType {
1304     match *spirv.id(id).instruction() {
1305         Instruction::TypeInt {
1306             width, signedness, ..
1307         } => {
1308             assert!(!ignore_first_array);
1309             ShaderInterfaceEntryType {
1310                 base_type: match signedness {
1311                     0 => ShaderScalarType::Uint,
1312                     1 => ShaderScalarType::Sint,
1313                     _ => unreachable!(),
1314                 },
1315                 num_components: 1,
1316                 num_elements: 1,
1317                 is_64bit: match width {
1318                     8 | 16 | 32 => false,
1319                     64 => true,
1320                     _ => unimplemented!(),
1321                 },
1322             }
1323         }
1324         Instruction::TypeFloat { width, .. } => {
1325             assert!(!ignore_first_array);
1326             ShaderInterfaceEntryType {
1327                 base_type: ShaderScalarType::Float,
1328                 num_components: 1,
1329                 num_elements: 1,
1330                 is_64bit: match width {
1331                     16 | 32 => false,
1332                     64 => true,
1333                     _ => unimplemented!(),
1334                 },
1335             }
1336         }
1337         Instruction::TypeVector {
1338             component_type,
1339             component_count,
1340             ..
1341         } => {
1342             assert!(!ignore_first_array);
1343             ShaderInterfaceEntryType {
1344                 num_components: component_count,
1345                 ..shader_interface_type_of(spirv, component_type, false)
1346             }
1347         }
1348         Instruction::TypeMatrix {
1349             column_type,
1350             column_count,
1351             ..
1352         } => {
1353             assert!(!ignore_first_array);
1354             ShaderInterfaceEntryType {
1355                 num_elements: column_count,
1356                 ..shader_interface_type_of(spirv, column_type, false)
1357             }
1358         }
1359         Instruction::TypeArray {
1360             element_type,
1361             length,
1362             ..
1363         } => {
1364             if ignore_first_array {
1365                 shader_interface_type_of(spirv, element_type, false)
1366             } else {
1367                 let mut ty = shader_interface_type_of(spirv, element_type, false);
1368                 let num_elements = spirv
1369                     .instructions()
1370                     .iter()
1371                     .find_map(|instruction| match *instruction {
1372                         Instruction::Constant {
1373                             result_id,
1374                             ref value,
1375                             ..
1376                         } if result_id == length => Some(value.clone()),
1377                         _ => None,
1378                     })
1379                     .expect("failed to find array length")
1380                     .iter()
1381                     .rev()
1382                     .fold(0u64, |a, &b| (a << 32) | b as u64)
1383                     as u32;
1384                 ty.num_elements *= num_elements;
1385                 ty
1386             }
1387         }
1388         Instruction::TypePointer { ty, .. } => {
1389             shader_interface_type_of(spirv, ty, ignore_first_array)
1390         }
1391         _ => panic!("Type {} not found or invalid", id),
1392     }
1393 }
1394 
1395 /// Returns true if a `BuiltIn` decorator is applied on an id.
is_builtin(spirv: &Spirv, id: Id) -> bool1396 fn is_builtin(spirv: &Spirv, id: Id) -> bool {
1397     let id_info = spirv.id(id);
1398 
1399     if id_info.iter_decoration().any(|instruction| {
1400         matches!(
1401             instruction,
1402             Instruction::Decorate {
1403                 decoration: Decoration::BuiltIn { .. },
1404                 ..
1405             }
1406         )
1407     }) {
1408         return true;
1409     }
1410 
1411     if id_info
1412         .iter_members()
1413         .flat_map(|member_info| member_info.iter_decoration())
1414         .any(|instruction| {
1415             matches!(
1416                 instruction,
1417                 Instruction::MemberDecorate {
1418                     decoration: Decoration::BuiltIn { .. },
1419                     ..
1420                 }
1421             )
1422         })
1423     {
1424         return true;
1425     }
1426 
1427     match id_info.instruction() {
1428         Instruction::Variable {
1429             result_type_id: ty, ..
1430         }
1431         | Instruction::TypeArray {
1432             element_type: ty, ..
1433         }
1434         | Instruction::TypeRuntimeArray {
1435             element_type: ty, ..
1436         }
1437         | Instruction::TypePointer { ty, .. } => is_builtin(spirv, *ty),
1438         Instruction::TypeStruct { member_types, .. } => {
1439             member_types.iter().any(|ty| is_builtin(spirv, *ty))
1440         }
1441         _ => false,
1442     }
1443 }
1444