// Copyright (c) 2021 The Vulkano developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.

//! Extraction of information from SPIR-V modules that is needed by the rest of Vulkano.

use super::{DescriptorBindingRequirements, FragmentShaderExecution, FragmentTestsStages};
use crate::{
    descriptor_set::layout::DescriptorType,
    image::view::ImageViewType,
    pipeline::layout::PushConstantRange,
    shader::{
        spirv::{
            Capability, Decoration, Dim, ExecutionMode, ExecutionModel, Id, Instruction, Spirv,
            StorageClass,
        },
        DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, GeometryShaderExecution,
        GeometryShaderInput, ShaderExecution, ShaderInterface, ShaderInterfaceEntry,
        ShaderInterfaceEntryType, ShaderScalarType, ShaderStage,
        SpecializationConstantRequirements,
    },
    DeviceSize,
};
use ahash::{HashMap, HashSet};
use std::borrow::Cow;

/// Returns an iterator of the capabilities used by `spirv`.
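///
/// A minimal usage sketch (assuming `spirv` is an already-parsed [`Spirv`] module; the
/// capability checked for is just an example):
///
/// ```ignore
/// use vulkano::shader::spirv::Capability;
///
/// // Check whether the module declares a particular capability.
/// let uses_geometry = spirv_capabilities(&spirv)
///     .any(|capability| matches!(capability, Capability::Geometry));
/// ```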
#[inline]
pub fn spirv_capabilities(spirv: &Spirv) -> impl Iterator<Item = &Capability> {
    spirv
        .iter_capability()
        .filter_map(|instruction| match instruction {
            Instruction::Capability { capability } => Some(capability),
            _ => None,
        })
}

/// Returns an iterator of the extensions used by `spirv`.
#[inline]
pub fn spirv_extensions(spirv: &Spirv) -> impl Iterator<Item = &str> {
    spirv
        .iter_extension()
        .filter_map(|instruction| match instruction {
            Instruction::Extension { name } => Some(name.as_str()),
            _ => None,
        })
}

/// Returns an iterator over all entry points in `spirv`, with information about the entry point.
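///
/// A minimal usage sketch (assuming `spirv` is an already-parsed [`Spirv`] module):
///
/// ```ignore
/// for (name, execution_model, info) in entry_points(&spirv) {
///     println!(
///         "entry point `{}` ({:?}): {} descriptor binding(s)",
///         name,
///         execution_model,
///         info.descriptor_binding_requirements.len(),
///     );
/// }
/// ```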
#[inline]
pub fn entry_points(
    spirv: &Spirv,
) -> impl Iterator<Item = (String, ExecutionModel, EntryPointInfo)> + '_ {
    let interface_variables = interface_variables(spirv);

    spirv.iter_entry_point().filter_map(move |instruction| {
        let (execution_model, function_id, entry_point_name, interface) = match instruction {
            Instruction::EntryPoint {
                execution_model,
                entry_point,
                name,
                interface,
                ..
            } => (*execution_model, *entry_point, name, interface),
            _ => return None,
        };

        let execution = shader_execution(spirv, execution_model, function_id);
        let stage = ShaderStage::from(execution);

        let descriptor_binding_requirements = inspect_entry_point(
            &interface_variables.descriptor_binding,
            spirv,
            stage,
            function_id,
        );
        let push_constant_requirements = push_constant_requirements(spirv, stage);
        let specialization_constant_requirements = specialization_constant_requirements(spirv);
        let input_interface = shader_interface(
            spirv,
            interface,
            StorageClass::Input,
            matches!(
                execution_model,
                ExecutionModel::TessellationControl
                    | ExecutionModel::TessellationEvaluation
                    | ExecutionModel::Geometry
            ),
        );
        let output_interface = shader_interface(
            spirv,
            interface,
            StorageClass::Output,
            matches!(execution_model, ExecutionModel::TessellationControl),
        );

        Some((
            entry_point_name.clone(),
            execution_model,
            EntryPointInfo {
                execution,
                descriptor_binding_requirements,
                push_constant_requirements,
                specialization_constant_requirements,
                input_interface,
                output_interface,
            },
        ))
    })
}

/// Extracts the `ShaderExecution` for the entry point `function_id` from `spirv`.
fn shader_execution(
    spirv: &Spirv,
    execution_model: ExecutionModel,
    function_id: Id,
) -> ShaderExecution {
    match execution_model {
        ExecutionModel::Vertex => ShaderExecution::Vertex,

        ExecutionModel::TessellationControl => ShaderExecution::TessellationControl,

        ExecutionModel::TessellationEvaluation => ShaderExecution::TessellationEvaluation,

        ExecutionModel::Geometry => {
            let mut input = None;

            for instruction in spirv.iter_execution_mode() {
                let mode = match instruction {
                    Instruction::ExecutionMode {
                        entry_point, mode, ..
                    } if *entry_point == function_id => mode,
                    _ => continue,
                };

                match mode {
                    ExecutionMode::InputPoints => {
                        input = Some(GeometryShaderInput::Points);
                    }
                    ExecutionMode::InputLines => {
                        input = Some(GeometryShaderInput::Lines);
                    }
                    ExecutionMode::InputLinesAdjacency => {
                        input = Some(GeometryShaderInput::LinesWithAdjacency);
                    }
                    ExecutionMode::Triangles => {
                        input = Some(GeometryShaderInput::Triangles);
                    }
                    ExecutionMode::InputTrianglesAdjacency => {
                        input = Some(GeometryShaderInput::TrianglesWithAdjacency);
                    }
                    _ => (),
                }
            }

            ShaderExecution::Geometry(GeometryShaderExecution {
                input: input
                    .expect("Geometry shader does not have an input primitive ExecutionMode"),
            })
        }

        ExecutionModel::Fragment => {
            let mut fragment_tests_stages = FragmentTestsStages::Late;

            for instruction in spirv.iter_execution_mode() {
                let mode = match instruction {
                    Instruction::ExecutionMode {
                        entry_point, mode, ..
                    } if *entry_point == function_id => mode,
                    _ => continue,
                };

                #[allow(clippy::single_match)]
                match mode {
                    ExecutionMode::EarlyFragmentTests => {
                        fragment_tests_stages = FragmentTestsStages::Early;
                    }
                    /*ExecutionMode::EarlyAndLateFragmentTestsAMD => {
                        fragment_tests_stages = FragmentTestsStages::EarlyAndLate;
                    }*/
                    _ => (),
                }
            }

            ShaderExecution::Fragment(FragmentShaderExecution {
                fragment_tests_stages,
            })
        }

        ExecutionModel::GLCompute => ShaderExecution::Compute,

        ExecutionModel::RayGenerationKHR => ShaderExecution::RayGeneration,
        ExecutionModel::IntersectionKHR => ShaderExecution::Intersection,
        ExecutionModel::AnyHitKHR => ShaderExecution::AnyHit,
        ExecutionModel::ClosestHitKHR => ShaderExecution::ClosestHit,
        ExecutionModel::MissKHR => ShaderExecution::Miss,
        ExecutionModel::CallableKHR => ShaderExecution::Callable,

        ExecutionModel::TaskNV => ShaderExecution::Task,
        ExecutionModel::MeshNV => ShaderExecution::Mesh,

        ExecutionModel::Kernel => todo!(),
    }
}

#[derive(Clone, Debug, Default)]
struct InterfaceVariables {
    descriptor_binding: HashMap<Id, DescriptorBindingVariable>,
}

// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface.
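//
// A single binding of the descriptor set interface: the `(set, binding)` pair comes from the
// variable's `DescriptorSet` and `Binding` decorations, and `reqs` accumulates the requirements
// derived from how the variable is used in the inspected entry point.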
#[derive(Clone, Debug)]
struct DescriptorBindingVariable {
    set: u32,
    binding: u32,
    reqs: DescriptorBindingRequirements,
}

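/// Collects the global `Variable`s of `spirv` whose storage class makes them part of the
/// descriptor set interface (`StorageBuffer`, `Uniform` and `UniformConstant`), together with
/// their binding requirements.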
fn interface_variables(spirv: &Spirv) -> InterfaceVariables {
    let mut variables = InterfaceVariables::default();

    for instruction in spirv.iter_global() {
        if let Instruction::Variable {
            result_id,
            result_type_id: _,
            storage_class,
            ..
        } = instruction
        {
            match storage_class {
                StorageClass::StorageBuffer
                | StorageClass::Uniform
                | StorageClass::UniformConstant => {
                    variables.descriptor_binding.insert(
                        *result_id,
                        descriptor_binding_requirements_of(spirv, *result_id),
                    );
                }
                _ => (),
            }
        }
    }

    variables
}

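/// Walks the entry point `entry_point` and every function it (transitively) calls, and records
/// how each descriptor binding variable in `global` is accessed. The result maps each
/// `(set, binding)` pair to its accumulated `DescriptorBindingRequirements`.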
fn inspect_entry_point(
    global: &HashMap<Id, DescriptorBindingVariable>,
    spirv: &Spirv,
    stage: ShaderStage,
    entry_point: Id,
) -> HashMap<(u32, u32), DescriptorBindingRequirements> {
    struct Context<'a> {
        global: &'a HashMap<Id, DescriptorBindingVariable>,
        spirv: &'a Spirv,
        stage: ShaderStage,
        inspected_functions: HashSet<Id>,
        result: HashMap<Id, DescriptorBindingVariable>,
    }

    impl<'a> Context<'a> {
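        // Follows `chain` backwards from `id` and, if the end of the chain is one of the global
        // descriptor binding variables, returns that variable (entered into `self.result`)
        // together with the array index used to access it, if one can be determined.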
        fn instruction_chain<const N: usize>(
            &mut self,
            chain: [fn(&Spirv, Id) -> Option<Id>; N],
            id: Id,
        ) -> Option<(&mut DescriptorBindingVariable, Option<u32>)> {
            let id = chain
                .into_iter()
                .try_fold(id, |id, func| func(self.spirv, id))?;

            if let Some(variable) = self.global.get(&id) {
                // Variable was accessed without an access chain; return with index 0.
                let variable = self.result.entry(id).or_insert_with(|| variable.clone());
                variable.reqs.stages = self.stage.into();
                return Some((variable, Some(0)));
            }

            let (id, indexes) = match *self.spirv.id(id).instruction() {
                Instruction::AccessChain {
                    base, ref indexes, ..
                } => (base, indexes),
                _ => return None,
            };

            if let Some(variable) = self.global.get(&id) {
                // Variable was accessed with an access chain.
                // Retrieve the index from the instruction if it's a constant value.
                // TODO: handle a `None` index too?
                let index = match *self.spirv.id(*indexes.first().unwrap()).instruction() {
                    Instruction::Constant { ref value, .. } => Some(value[0]),
                    _ => None,
                };
                let variable = self.result.entry(id).or_insert_with(|| variable.clone());
                variable.reqs.stages = self.stage.into();
                return Some((variable, index));
            }

            None
        }

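        // Scans the body of `function` for instructions that access descriptor bindings, and
        // recurses into every function it calls. `desc_reqs` and the `inst_*` helpers below
        // translate an operand id into the descriptor requirements entry it refers to.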
        fn inspect_entry_point_r(&mut self, function: Id) {
            fn desc_reqs(
                descriptor_variable: Option<(&mut DescriptorBindingVariable, Option<u32>)>,
            ) -> Option<&mut DescriptorRequirements> {
                descriptor_variable
                    .map(|(variable, index)| variable.reqs.descriptors.entry(index).or_default())
            }

            fn inst_image_texel_pointer(spirv: &Spirv, id: Id) -> Option<Id> {
                match *spirv.id(id).instruction() {
                    Instruction::ImageTexelPointer { image, .. } => Some(image),
                    _ => None,
                }
            }

            fn inst_load(spirv: &Spirv, id: Id) -> Option<Id> {
                match *spirv.id(id).instruction() {
                    Instruction::Load { pointer, .. } => Some(pointer),
                    _ => None,
                }
            }

            fn inst_sampled_image(spirv: &Spirv, id: Id) -> Option<Id> {
                match *spirv.id(id).instruction() {
                    Instruction::SampledImage { sampler, .. } => Some(sampler),
                    _ => Some(id),
                }
            }

            self.inspected_functions.insert(function);
            let mut in_function = false;

            for instruction in self.spirv.instructions() {
                if !in_function {
                    match *instruction {
                        Instruction::Function { result_id, .. } if result_id == function => {
                            in_function = true;
                        }
                        _ => {}
                    }
                } else {
                    let stage = self.stage;

                    match *instruction {
                        Instruction::AtomicLoad { pointer, .. } => {
                            // Storage buffer
                            if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
                            {
                                desc_reqs.memory_read = stage.into();
                            }

                            // Storage image
                            if let Some(desc_reqs) = desc_reqs(
                                self.instruction_chain([inst_image_texel_pointer], pointer),
                            ) {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.storage_image_atomic = true;
                            }
                        }

                        Instruction::AtomicStore { pointer, .. } => {
                            // Storage buffer
                            if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
                            {
                                desc_reqs.memory_write = stage.into();
                            }

                            // Storage image
                            if let Some(desc_reqs) = desc_reqs(
                                self.instruction_chain([inst_image_texel_pointer], pointer),
                            ) {
                                desc_reqs.memory_write = stage.into();
                                desc_reqs.storage_image_atomic = true;
                            }
                        }

                        Instruction::AtomicExchange { pointer, .. }
                        | Instruction::AtomicCompareExchange { pointer, .. }
                        | Instruction::AtomicCompareExchangeWeak { pointer, .. }
                        | Instruction::AtomicIIncrement { pointer, .. }
                        | Instruction::AtomicIDecrement { pointer, .. }
                        | Instruction::AtomicIAdd { pointer, .. }
                        | Instruction::AtomicISub { pointer, .. }
                        | Instruction::AtomicSMin { pointer, .. }
                        | Instruction::AtomicUMin { pointer, .. }
                        | Instruction::AtomicSMax { pointer, .. }
                        | Instruction::AtomicUMax { pointer, .. }
                        | Instruction::AtomicAnd { pointer, .. }
                        | Instruction::AtomicOr { pointer, .. }
                        | Instruction::AtomicXor { pointer, .. }
                        | Instruction::AtomicFlagTestAndSet { pointer, .. }
                        | Instruction::AtomicFlagClear { pointer, .. }
                        | Instruction::AtomicFMinEXT { pointer, .. }
                        | Instruction::AtomicFMaxEXT { pointer, .. }
                        | Instruction::AtomicFAddEXT { pointer, .. } => {
                            // Storage buffer
                            if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.memory_write = stage.into();
                            }

                            // Storage image
                            if let Some(desc_reqs) = desc_reqs(
                                self.instruction_chain([inst_image_texel_pointer], pointer),
                            ) {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.memory_write = stage.into();
                                desc_reqs.storage_image_atomic = true;
                            }
                        }

                        Instruction::CopyMemory { target, source, .. } => {
                            self.instruction_chain([], target);
                            self.instruction_chain([], source);
                        }

                        Instruction::CopyObject { operand, .. } => {
                            self.instruction_chain([], operand);
                        }

                        Instruction::ExtInst { ref operands, .. } => {
                            // We don't know which extended instructions take pointers,
                            // so we must interpret every operand as a pointer.
                            for &operand in operands {
                                self.instruction_chain([], operand);
                            }
                        }

                        Instruction::FunctionCall {
                            function,
                            ref arguments,
                            ..
                        } => {
                            // Rather than trying to figure out the type of each argument, we just
                            // try all of them as pointers.
                            for &argument in arguments {
                                self.instruction_chain([], argument);
                            }

                            if !self.inspected_functions.contains(&function) {
                                self.inspect_entry_point_r(function);
                            }
                        }

                        Instruction::FunctionEnd => return,

                        Instruction::ImageGather {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseGather {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_ycbcr_conversion = true;

                                if image_operands.as_ref().map_or(false, |image_operands| {
                                    image_operands.bias.is_some()
                                        || image_operands.const_offset.is_some()
                                        || image_operands.offset.is_some()
                                }) {
                                    desc_reqs.sampler_no_unnormalized_coordinates = true;
                                }
                            }
                        }

                        Instruction::ImageDrefGather { sampled_image, .. }
                        | Instruction::ImageSparseDrefGather { sampled_image, .. } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_unnormalized_coordinates = true;
                                desc_reqs.sampler_no_ycbcr_conversion = true;
                            }
                        }

                        Instruction::ImageSampleImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSampleProjImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleProjImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_unnormalized_coordinates = true;

                                if image_operands.as_ref().map_or(false, |image_operands| {
                                    image_operands.const_offset.is_some()
                                        || image_operands.offset.is_some()
                                }) {
                                    desc_reqs.sampler_no_ycbcr_conversion = true;
                                }
                            }
                        }

                        Instruction::ImageSampleProjExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleProjExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_unnormalized_coordinates = true;

                                if image_operands.const_offset.is_some()
                                    || image_operands.offset.is_some()
                                {
                                    desc_reqs.sampler_no_ycbcr_conversion = true;
                                }
                            }
                        }

                        Instruction::ImageSampleDrefImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSampleProjDrefImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleDrefImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleProjDrefImplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_unnormalized_coordinates = true;
                                desc_reqs.sampler_compare = true;

                                if image_operands.as_ref().map_or(false, |image_operands| {
                                    image_operands.const_offset.is_some()
                                        || image_operands.offset.is_some()
                                }) {
                                    desc_reqs.sampler_no_ycbcr_conversion = true;
                                }
                            }
                        }

                        Instruction::ImageSampleDrefExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSampleProjDrefExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleDrefExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleProjDrefExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();
                                desc_reqs.sampler_no_unnormalized_coordinates = true;
                                desc_reqs.sampler_compare = true;

                                if image_operands.const_offset.is_some()
                                    || image_operands.offset.is_some()
                                {
                                    desc_reqs.sampler_no_ycbcr_conversion = true;
                                }
                            }
                        }

                        Instruction::ImageSampleExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        }
                        | Instruction::ImageSparseSampleExplicitLod {
                            sampled_image,
                            image_operands,
                            ..
                        } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain(
                                    [inst_sampled_image, inst_load],
                                    sampled_image,
                                ))
                            {
                                desc_reqs.memory_read = stage.into();

                                if image_operands.bias.is_some()
                                    || image_operands.const_offset.is_some()
                                    || image_operands.offset.is_some()
                                {
                                    desc_reqs.sampler_no_unnormalized_coordinates = true;
                                }

                                if image_operands.const_offset.is_some()
                                    || image_operands.offset.is_some()
                                {
                                    desc_reqs.sampler_no_ycbcr_conversion = true;
                                }
                            }
                        }

                        Instruction::ImageTexelPointer { image, .. } => {
                            self.instruction_chain([], image);
                        }

                        Instruction::ImageRead { image, .. } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain([inst_load], image))
                            {
                                desc_reqs.memory_read = stage.into();
                            }
                        }

                        Instruction::ImageWrite { image, .. } => {
                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain([inst_load], image))
                            {
                                desc_reqs.memory_write = stage.into();
                            }
                        }

                        Instruction::Load { pointer, .. } => {
                            if let Some((binding_variable, index)) =
                                self.instruction_chain([], pointer)
                            {
                                // Only loads on buffers access memory directly.
                                // Loads on images load the image object itself, but don't touch
                                // the texels in memory yet.
                                if binding_variable.reqs.descriptor_types.iter().any(|ty| {
                                    matches!(
                                        ty,
                                        DescriptorType::UniformBuffer
                                            | DescriptorType::UniformBufferDynamic
                                            | DescriptorType::StorageBuffer
                                            | DescriptorType::StorageBufferDynamic
                                    )
                                }) {
                                    if let Some(desc_reqs) =
                                        desc_reqs(Some((binding_variable, index)))
                                    {
                                        desc_reqs.memory_read = stage.into();
                                    }
                                }
                            }
                        }

                        Instruction::SampledImage { image, sampler, .. } => {
                            let identifier = match self.instruction_chain([inst_load], image) {
                                Some((variable, Some(index))) => DescriptorIdentifier {
                                    set: variable.set,
                                    binding: variable.binding,
                                    index,
                                },
                                _ => continue,
                            };

                            if let Some(desc_reqs) =
                                desc_reqs(self.instruction_chain([inst_load], sampler))
                            {
                                desc_reqs.sampler_with_images.insert(identifier);
                            }
                        }

                        Instruction::Store { pointer, .. } => {
                            // This can only apply to buffers, right?
                            if let Some(desc_reqs) = desc_reqs(self.instruction_chain([], pointer))
                            {
                                desc_reqs.memory_write = stage.into();
                            }
                        }

                        _ => (),
                    }
                }
            }
        }
    }

    let mut context = Context {
        global,
        spirv,
        stage,
        inspected_functions: HashSet::default(),
        result: HashMap::default(),
    };
    context.inspect_entry_point_r(entry_point);

    context
        .result
        .into_values()
        .map(|variable| ((variable.set, variable.binding), variable.reqs))
        .collect()
}

/// Returns a `DescriptorBindingVariable` (descriptor set, binding and requirements) for the
/// variable `variable_id`, based on the type that it points to.
///
/// See also section 14.5.2 of the Vulkan specs: Descriptor Set Interface.
fn descriptor_binding_requirements_of(spirv: &Spirv, variable_id: Id) -> DescriptorBindingVariable {
    let variable_id_info = spirv.id(variable_id);

    let mut reqs = DescriptorBindingRequirements {
        descriptor_count: Some(1),
        ..Default::default()
    };

    let (mut next_type_id, is_storage_buffer) = {
        let variable_type_id = match *variable_id_info.instruction() {
            Instruction::Variable { result_type_id, .. } => result_type_id,
            _ => panic!("Id {} is not a variable", variable_id),
        };

        match *spirv.id(variable_type_id).instruction() {
            Instruction::TypePointer {
                ty, storage_class, ..
            } => (Some(ty), storage_class == StorageClass::StorageBuffer),
            _ => panic!(
                "Variable {} result_type_id does not refer to a TypePointer instruction",
                variable_id
            ),
        }
    };

    while let Some(id) = next_type_id {
        let id_info = spirv.id(id);

        next_type_id = match *id_info.instruction() {
            Instruction::TypeStruct { .. } => {
                let decoration_block = id_info.iter_decoration().any(|instruction| {
                    matches!(
                        instruction,
                        Instruction::Decorate {
                            decoration: Decoration::Block,
                            ..
                        }
                    )
                });

                let decoration_buffer_block = id_info.iter_decoration().any(|instruction| {
                    matches!(
                        instruction,
                        Instruction::Decorate {
                            decoration: Decoration::BufferBlock,
                            ..
                        }
                    )
                });

                assert!(
                    decoration_block ^ decoration_buffer_block,
                    "Structs in shader interface are expected to be decorated with one of Block or \
                    BufferBlock",
                );

                if decoration_buffer_block || decoration_block && is_storage_buffer {
                    reqs.descriptor_types = vec![
                        DescriptorType::StorageBuffer,
                        DescriptorType::StorageBufferDynamic,
                    ];
                } else {
                    reqs.descriptor_types = vec![
                        DescriptorType::UniformBuffer,
                        DescriptorType::UniformBufferDynamic,
                    ];
                };

                None
            }

            Instruction::TypeImage {
                sampled_type,
                dim,
                arrayed,
                ms,
                sampled,
                image_format,
                ..
            } => {
                assert!(
                    sampled != 0,
                    "Vulkan requires that variables of type OpTypeImage have a Sampled operand of \
                    1 or 2",
                );
                reqs.image_format = image_format.into();
                reqs.image_multisampled = ms != 0;
                reqs.image_scalar_type = Some(match *spirv.id(sampled_type).instruction() {
                    Instruction::TypeInt {
                        width, signedness, ..
                    } => {
                        assert!(width == 32); // TODO: 64-bit components
                        match signedness {
                            0 => ShaderScalarType::Uint,
                            1 => ShaderScalarType::Sint,
                            _ => unreachable!(),
                        }
                    }
                    Instruction::TypeFloat { width, .. } => {
                        assert!(width == 32); // TODO: 64-bit components
                        ShaderScalarType::Float
                    }
                    _ => unreachable!(),
                });

                match dim {
                    Dim::SubpassData => {
                        assert!(
                            reqs.image_format.is_none(),
                            "If Dim is SubpassData, Image Format must be Unknown",
                        );
                        assert!(sampled == 2, "If Dim is SubpassData, Sampled must be 2");
                        assert!(arrayed == 0, "If Dim is SubpassData, Arrayed must be 0");

                        reqs.descriptor_types = vec![DescriptorType::InputAttachment];
                    }
                    Dim::Buffer => {
                        if sampled == 1 {
                            reqs.descriptor_types = vec![DescriptorType::UniformTexelBuffer];
                        } else {
                            reqs.descriptor_types = vec![DescriptorType::StorageTexelBuffer];
                        }
                    }
                    _ => {
                        reqs.image_view_type = Some(match (dim, arrayed) {
                            (Dim::Dim1D, 0) => ImageViewType::Dim1d,
                            (Dim::Dim1D, 1) => ImageViewType::Dim1dArray,
                            (Dim::Dim2D, 0) => ImageViewType::Dim2d,
                            (Dim::Dim2D, 1) => ImageViewType::Dim2dArray,
                            (Dim::Dim3D, 0) => ImageViewType::Dim3d,
                            (Dim::Dim3D, 1) => {
                                panic!("Vulkan doesn't support arrayed 3D textures")
                            }
                            (Dim::Cube, 0) => ImageViewType::Cube,
                            (Dim::Cube, 1) => ImageViewType::CubeArray,
                            (Dim::Rect, _) => {
                                panic!("Vulkan doesn't support rectangle textures")
                            }
                            _ => unreachable!(),
                        });

                        if reqs.descriptor_types.is_empty() {
                            if sampled == 1 {
                                reqs.descriptor_types = vec![DescriptorType::SampledImage];
                            } else {
                                reqs.descriptor_types = vec![DescriptorType::StorageImage];
                            }
                        }
                    }
                }

                None
            }

            Instruction::TypeSampler { .. } => {
                reqs.descriptor_types = vec![DescriptorType::Sampler];

                None
            }

            Instruction::TypeSampledImage { image_type, .. } => {
                reqs.descriptor_types = vec![DescriptorType::CombinedImageSampler];

                Some(image_type)
            }

            Instruction::TypeArray {
                element_type,
                length,
                ..
            } => {
                let len = match spirv.id(length).instruction() {
                    Instruction::Constant { value, .. } => {
                        value.iter().rev().fold(0, |a, &b| (a << 32) | b as u64)
                    }
                    _ => panic!("failed to find array length"),
                };

                if let Some(count) = reqs.descriptor_count.as_mut() {
                    *count *= len as u32
                }

                Some(element_type)
            }

            Instruction::TypeRuntimeArray { element_type, .. } => {
                reqs.descriptor_count = None;

                Some(element_type)
            }

            Instruction::TypeAccelerationStructureKHR { .. } => None, // FIXME temporary workaround

            _ => {
                let name = variable_id_info
                    .iter_name()
                    .find_map(|instruction| match *instruction {
                        Instruction::Name { ref name, .. } => Some(name.as_str()),
                        _ => None,
                    })
                    .unwrap_or("__unnamed");

                panic!(
                    "Couldn't find relevant type for global variable `{}` (id {}, maybe \
                    unimplemented)",
                    name, variable_id,
                );
            }
        };
    }

    DescriptorBindingVariable {
        set: variable_id_info
            .iter_decoration()
            .find_map(|instruction| match *instruction {
                Instruction::Decorate {
                    decoration: Decoration::DescriptorSet { descriptor_set },
                    ..
                } => Some(descriptor_set),
                _ => None,
            })
            .unwrap(),
        binding: variable_id_info
            .iter_decoration()
            .find_map(|instruction| match *instruction {
                Instruction::Decorate {
                    decoration: Decoration::Binding { binding_point },
                    ..
                } => Some(binding_point),
                _ => None,
            })
            .unwrap(),
        reqs,
    }
}

/// Extracts the `PushConstantRange` from `spirv`.
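///
/// Only the first global `PushConstant` type pointer is considered; the returned range starts at
/// the smallest member offset of the pointed-to struct and ends at the struct's total size.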
fn push_constant_requirements(spirv: &Spirv, stage: ShaderStage) -> Option<PushConstantRange> {
    spirv
        .iter_global()
        .find_map(|instruction| match *instruction {
            Instruction::TypePointer {
                ty,
                storage_class: StorageClass::PushConstant,
                ..
            } => {
                let id_info = spirv.id(ty);
                assert!(matches!(
                    id_info.instruction(),
                    Instruction::TypeStruct { .. }
                ));
                let start = offset_of_struct(spirv, ty);
                let end =
                    size_of_type(spirv, ty).expect("Found runtime-sized push constants") as u32;

                Some(PushConstantRange {
                    stages: stage.into(),
                    offset: start,
                    size: end - start,
                })
            }
            _ => None,
        })
}

/// Extracts the `SpecializationConstantRequirements` from `spirv`.
fn specialization_constant_requirements(
    spirv: &Spirv,
) -> HashMap<u32, SpecializationConstantRequirements> {
    spirv
        .iter_global()
        .filter_map(|instruction| {
            match *instruction {
                Instruction::SpecConstantTrue {
                    result_type_id,
                    result_id,
                }
                | Instruction::SpecConstantFalse {
                    result_type_id,
                    result_id,
                }
                | Instruction::SpecConstant {
                    result_type_id,
                    result_id,
                    ..
                }
                | Instruction::SpecConstantComposite {
                    result_type_id,
                    result_id,
                    ..
                } => spirv
                    .id(result_id)
                    .iter_decoration()
                    .find_map(|instruction| match *instruction {
                        Instruction::Decorate {
                            decoration:
                                Decoration::SpecId {
                                    specialization_constant_id,
                                },
                            ..
                        } => Some(specialization_constant_id),
                        _ => None,
                    })
                    .map(|constant_id| {
                        let size = match *spirv.id(result_type_id).instruction() {
                            Instruction::TypeBool { .. } => {
                                // Translate bool to Bool32
                                std::mem::size_of::<ash::vk::Bool32>() as DeviceSize
                            }
                            _ => size_of_type(spirv, result_type_id)
                                .expect("Found runtime-sized specialization constant"),
                        };
                        (constant_id, SpecializationConstantRequirements { size })
                    }),
                _ => None,
            }
        })
        .collect()
}

/// Extracts the `ShaderInterface` with the given storage class from `spirv`.
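///
/// Only variables listed in `interface` that have the storage class `filter_storage_class` and
/// are not built-ins are considered. `ignore_first_array` is set for stages whose interface
/// variables are per-vertex arrays (tessellation and geometry inputs, tessellation control
/// outputs), so that the outer array level is not counted as part of the element type.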
fn shader_interface(
    spirv: &Spirv,
    interface: &[Id],
    filter_storage_class: StorageClass,
    ignore_first_array: bool,
) -> ShaderInterface {
    let elements: Vec<_> = interface
        .iter()
        .filter_map(|&id| {
            let (result_type_id, result_id) = match *spirv.id(id).instruction() {
                Instruction::Variable {
                    result_type_id,
                    result_id,
                    storage_class,
                    ..
                } if storage_class == filter_storage_class => (result_type_id, result_id),
                _ => return None,
            };

            if is_builtin(spirv, result_id) {
                return None;
            }

            let id_info = spirv.id(result_id);

            let name = id_info
                .iter_name()
                .find_map(|instruction| match *instruction {
                    Instruction::Name { ref name, .. } => Some(Cow::Owned(name.clone())),
                    _ => None,
                });

            let location = id_info
                .iter_decoration()
                .find_map(|instruction| match *instruction {
                    Instruction::Decorate {
                        decoration: Decoration::Location { location },
                        ..
                    } => Some(location),
                    _ => None,
                })
                .unwrap_or_else(|| {
                    panic!(
                        "Input/output variable with id {} (name {:?}) is missing a location",
                        result_id, name,
                    )
                });
            let component = id_info
                .iter_decoration()
                .find_map(|instruction| match *instruction {
                    Instruction::Decorate {
                        decoration: Decoration::Component { component },
                        ..
                    } => Some(component),
                    _ => None,
                })
                .unwrap_or(0);

            let ty = shader_interface_type_of(spirv, result_type_id, ignore_first_array);
            assert!(ty.num_elements >= 1);

            Some(ShaderInterfaceEntry {
                location,
                component,
                ty,
                name,
            })
        })
        .collect();

    // Checking for overlapping elements.
    for (offset, element1) in elements.iter().enumerate() {
        for element2 in elements.iter().skip(offset + 1) {
            if element1.location == element2.location
                || (element1.location < element2.location
                    && element1.location + element1.ty.num_locations() > element2.location)
                || (element2.location < element1.location
                    && element2.location + element2.ty.num_locations() > element1.location)
            {
                panic!(
                    "The locations of attributes `{:?}` ({}..{}) and `{:?}` ({}..{}) overlap",
                    element1.name,
                    element1.location,
                    element1.location + element1.ty.num_locations(),
                    element2.name,
                    element2.location,
                    element2.location + element2.ty.num_locations(),
                );
            }
        }
    }

    ShaderInterface { elements }
}

/// Returns the size of a type, or `None` if its size cannot be determined.
fn size_of_type(spirv: &Spirv, id: Id) -> Option<DeviceSize> {
    let id_info = spirv.id(id);

    match *id_info.instruction() {
        Instruction::TypeBool { .. } => {
            panic!("Can't put booleans in structs")
        }
        Instruction::TypeInt { width, .. } | Instruction::TypeFloat { width, .. } => {
            assert!(width % 8 == 0);
            Some(width as DeviceSize / 8)
        }
        Instruction::TypeVector {
            component_type,
            component_count,
            ..
        } => size_of_type(spirv, component_type)
            .map(|component_size| component_size * component_count as DeviceSize),
        Instruction::TypeMatrix {
            column_type,
            column_count,
            ..
        } => {
            // FIXME: row-major or column-major
            size_of_type(spirv, column_type)
                .map(|column_size| column_size * column_count as DeviceSize)
        }
        Instruction::TypeArray { length, .. } => {
            let stride = id_info
                .iter_decoration()
                .find_map(|instruction| match *instruction {
                    Instruction::Decorate {
                        decoration: Decoration::ArrayStride { array_stride },
                        ..
                    } => Some(array_stride),
                    _ => None,
                })
                .unwrap();
            let length = match spirv.id(length).instruction() {
                Instruction::Constant { value, .. } => Some(
                    value
                        .iter()
                        .rev()
                        .fold(0u64, |a, &b| (a << 32) | b as DeviceSize),
                ),
                _ => None,
            }
            .unwrap();

            Some(stride as DeviceSize * length)
        }
        Instruction::TypeRuntimeArray { .. } => None,
        Instruction::TypeStruct {
            ref member_types, ..
        } => {
            let mut end_of_struct = 0;

            for (&member, member_info) in member_types.iter().zip(id_info.iter_members()) {
                // Built-ins have an unknown size.
                if member_info.iter_decoration().any(|instruction| {
                    matches!(
                        instruction,
                        Instruction::MemberDecorate {
                            decoration: Decoration::BuiltIn { .. },
                            ..
                        }
                    )
                }) {
                    return None;
                }

                // Some structs don't have `Offset` decorations, for example when they are only
                // used as local variables. Ignore these.
                let offset =
                    member_info
                        .iter_decoration()
                        .find_map(|instruction| match *instruction {
                            Instruction::MemberDecorate {
                                decoration: Decoration::Offset { byte_offset },
                                ..
                            } => Some(byte_offset),
                            _ => None,
                        })?;
                let size = size_of_type(spirv, member)?;
                end_of_struct = end_of_struct.max(offset as DeviceSize + size);
            }

            Some(end_of_struct)
        }
        _ => panic!("Type {} not found", id),
    }
}

/// Returns the smallest offset of all members of a struct, or 0 if `id` is not a struct.
fn offset_of_struct(spirv: &Spirv, id: Id) -> u32 {
    spirv
        .id(id)
        .iter_members()
        .filter_map(|member_info| {
            member_info
                .iter_decoration()
                .find_map(|instruction| match *instruction {
                    Instruction::MemberDecorate {
                        decoration: Decoration::Offset { byte_offset },
                        ..
                    } => Some(byte_offset),
                    _ => None,
                })
        })
        .min()
        .unwrap_or(0)
}

/// If `ignore_first_array` is true, the function expects the outermost instruction to be
/// `OpTypeArray`. If that is the case, the `OpTypeArray` is ignored. If not, the function
/// panics.
fn shader_interface_type_of(
    spirv: &Spirv,
    id: Id,
    ignore_first_array: bool,
) -> ShaderInterfaceEntryType {
    match *spirv.id(id).instruction() {
        Instruction::TypeInt {
            width, signedness, ..
        } => {
            assert!(!ignore_first_array);
            ShaderInterfaceEntryType {
                base_type: match signedness {
                    0 => ShaderScalarType::Uint,
                    1 => ShaderScalarType::Sint,
                    _ => unreachable!(),
                },
                num_components: 1,
                num_elements: 1,
                is_64bit: match width {
                    8 | 16 | 32 => false,
                    64 => true,
                    _ => unimplemented!(),
                },
            }
        }
        Instruction::TypeFloat { width, .. } => {
            assert!(!ignore_first_array);
            ShaderInterfaceEntryType {
                base_type: ShaderScalarType::Float,
                num_components: 1,
                num_elements: 1,
                is_64bit: match width {
                    16 | 32 => false,
                    64 => true,
                    _ => unimplemented!(),
                },
            }
        }
        Instruction::TypeVector {
            component_type,
            component_count,
            ..
        } => {
            assert!(!ignore_first_array);
            ShaderInterfaceEntryType {
                num_components: component_count,
                ..shader_interface_type_of(spirv, component_type, false)
            }
        }
        Instruction::TypeMatrix {
            column_type,
            column_count,
            ..
        } => {
            assert!(!ignore_first_array);
            ShaderInterfaceEntryType {
                num_elements: column_count,
                ..shader_interface_type_of(spirv, column_type, false)
            }
        }
        Instruction::TypeArray {
            element_type,
            length,
            ..
        } => {
            if ignore_first_array {
                shader_interface_type_of(spirv, element_type, false)
            } else {
                let mut ty = shader_interface_type_of(spirv, element_type, false);
                let num_elements = spirv
                    .instructions()
                    .iter()
                    .find_map(|instruction| match *instruction {
                        Instruction::Constant {
                            result_id,
                            ref value,
                            ..
                        } if result_id == length => Some(value.clone()),
                        _ => None,
                    })
                    .expect("failed to find array length")
                    .iter()
                    .rev()
                    .fold(0u64, |a, &b| (a << 32) | b as u64)
                    as u32;
                ty.num_elements *= num_elements;
                ty
            }
        }
        Instruction::TypePointer { ty, .. } => {
            shader_interface_type_of(spirv, ty, ignore_first_array)
        }
        _ => panic!("Type {} not found or invalid", id),
    }
}

/// Returns true if a `BuiltIn` decoration is applied to `id` or to the type it refers to.
fn is_builtin(spirv: &Spirv, id: Id) -> bool {
    let id_info = spirv.id(id);

    if id_info.iter_decoration().any(|instruction| {
        matches!(
            instruction,
            Instruction::Decorate {
                decoration: Decoration::BuiltIn { .. },
                ..
            }
        )
    }) {
        return true;
    }

    if id_info
        .iter_members()
        .flat_map(|member_info| member_info.iter_decoration())
        .any(|instruction| {
            matches!(
                instruction,
                Instruction::MemberDecorate {
                    decoration: Decoration::BuiltIn { .. },
                    ..
                }
            )
        })
    {
        return true;
    }

    match id_info.instruction() {
        Instruction::Variable {
            result_type_id: ty, ..
        }
        | Instruction::TypeArray {
            element_type: ty, ..
        }
        | Instruction::TypeRuntimeArray {
            element_type: ty, ..
        }
        | Instruction::TypePointer { ty, .. } => is_builtin(spirv, *ty),
        Instruction::TypeStruct { member_types, .. } => {
            member_types.iter().any(|ty| is_builtin(spirv, *ty))
        }
        _ => false,
    }
}