From 2047b7b88afc825fbeb7ac50bda38e8a99a741ab Mon Sep 17 00:00:00 2001 From: Chengye YU Date: Tue, 10 Feb 2026 21:35:10 +0800 Subject: [PATCH 1/8] MSL: add initial cooperative matrix support --- .../cooperative-matrix-bfloat.asm.msl31.comp | 24 ++ .../cooperative-matrix-length.asm.msl31.comp | 16 ++ ...operative-matrix-load-store.asm.msl31.comp | 21 ++ .../cooperative-matrix-muladd.asm.msl31.comp | 38 +++ .../cooperative-matrix-bfloat.asm.msl31.comp | 52 ++++ .../cooperative-matrix-length.asm.msl31.comp | 42 +++ ...operative-matrix-load-store.asm.msl31.comp | 51 ++++ .../cooperative-matrix-muladd.asm.msl31.comp | 77 ++++++ spirv_msl.cpp | 240 ++++++++++++++++++ spirv_msl.hpp | 1 + 10 files changed, 562 insertions(+) create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp new file mode 100644 index 000000000..757500a1e --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp @@ -0,0 +1,24 @@ +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + bfloat data[1]; +}; + +kernel void main0(device SSBO& ssbo [[buffer(0)]]) +{ + simdgroup_bfloat8x8 _21; + simdgroup_load(_21, &ssbo.data[0u], 8u); + simdgroup_bfloat8x8 _22; + simdgroup_load(_22, &ssbo.data[0u], 8u); + simdgroup_bfloat8x8 _23; + simdgroup_load(_23, &ssbo.data[0u], 8u); + simdgroup_bfloat8x8 _24; + simdgroup_multiply_accumulate(_24, _21, _22, _23); + simdgroup_store(_24, &ssbo.data[0u], 8u); +} + diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp new file mode 100644 index 000000000..8161a5c9c --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp @@ -0,0 +1,16 @@ +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + uint data[1]; +}; + +kernel void main0(device SSBO& ssbo [[buffer(0)]]) +{ + ssbo.data[0u] = uint(sizeof(simdgroup_float8x8::storage_type) / sizeof(float)); +} + diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp new file mode 100644 index 000000000..005cf2c7e --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp @@ -0,0 +1,21 @@ +#include +#include +#include + +using namespace metal; + +struct SSBO +{ + float data[1]; +}; + +kernel void main0(device SSBO& ssbo [[buffer(0)]]) +{ + simdgroup_float8x8 _20; + simdgroup_load(_20, &ssbo.data[0u], 8u); + simdgroup_store(_20, &ssbo.data[0u], 8u); + simdgroup_float8x8 _21; + simdgroup_load(_21, &ssbo.data[0u], 8u, ulong2(0), true); + simdgroup_store(_21, &ssbo.data[0u], 8u, ulong2(0), true); +} + diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp new file mode 100644 index 000000000..f5a47982e --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp @@ -0,0 +1,38 @@ +#include +#include +#include + +using namespace metal; + +struct SSBO32 +{ + float data[1]; +}; + +struct SSBO16 +{ + half data[1]; +}; + +kernel void main0(device SSBO32& ssbo32 [[buffer(0)]], device SSBO16& ssbo16 [[buffer(1)]]) +{ + simdgroup_float8x8 _30; + simdgroup_load(_30, &ssbo32.data[0u], 8u); + simdgroup_float8x8 _31; + simdgroup_load(_31, &ssbo32.data[0u], 8u); + simdgroup_float8x8 _32; + simdgroup_load(_32, &ssbo32.data[0u], 8u); + simdgroup_float8x8 _33; + simdgroup_multiply_accumulate(_33, _30, _31, _32); + simdgroup_store(_33, &ssbo32.data[0u], 8u); + simdgroup_half8x8 _35; + simdgroup_load(_35, &ssbo16.data[0u], 8u); + simdgroup_half8x8 _36; + simdgroup_load(_36, &ssbo16.data[0u], 8u); + simdgroup_half8x8 _37; + simdgroup_load(_37, &ssbo16.data[0u], 8u); + simdgroup_half8x8 _38; + simdgroup_multiply_accumulate(_38, _35, _36, _37); + simdgroup_store(_38, &ssbo16.data[0u], 8u); +} + diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp new file mode 100644 index 000000000..18b795c2e --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp @@ -0,0 +1,52 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 50 +; Schema: 0 + OpCapability Shader + OpCapability CooperativeMatrixKHR + OpCapability BFloat16TypeKHR + OpCapability BFloat16CooperativeMatrixKHR + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_bfloat16" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "data" + OpName %ssbo "ssbo" + OpDecorate %arr_bf16 ArrayStride 2 + OpMemberDecorate %SSBO 0 Offset 0 + OpDecorate %SSBO Block + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bfloat = OpTypeFloat 16 BFloat16KHR + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %arr_bf16 = OpTypeRuntimeArray %bfloat + %SSBO = OpTypeStruct %arr_bf16 +%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer +%ptr_ssbo_bf16 = OpTypePointer StorageBuffer %bfloat +%coopmat_bf16_A = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_0 +%coopmat_bf16_B = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_1 +%coopmat_bf16_acc = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_2 + %main = OpFunction %void None %3 + %5 = OpLabel + %p0 = OpAccessChain %ptr_ssbo_bf16 %ssbo %uint_0 %uint_0 + %bf_A = OpCooperativeMatrixLoadKHR %coopmat_bf16_A %p0 %uint_0 %uint_8 + %bf_B = OpCooperativeMatrixLoadKHR %coopmat_bf16_B %p0 %uint_0 %uint_8 + %bf_C = OpCooperativeMatrixLoadKHR %coopmat_bf16_acc %p0 %uint_0 %uint_8 + %bf_D = OpCooperativeMatrixMulAddKHR %coopmat_bf16_acc %bf_A %bf_B %bf_C + OpCooperativeMatrixStoreKHR %p0 %bf_D %uint_0 %uint_8 + OpReturn + OpFunctionEnd diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp new file mode 100644 index 000000000..69e585869 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp @@ -0,0 +1,42 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 24 +; Schema: 0 + OpCapability Shader + OpCapability CooperativeMatrixKHR + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "data" + OpName %ssbo "ssbo" + OpDecorate %arr_uint ArrayStride 4 + OpMemberDecorate %SSBO 0 Offset 0 + OpDecorate %SSBO Block + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %float = OpTypeFloat 32 + %uint_0 = OpConstant %uint 0 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %arr_uint = OpTypeRuntimeArray %uint + %SSBO = OpTypeStruct %arr_uint +%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer +%ptr_ssbo_uint = OpTypePointer StorageBuffer %uint + %coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 + %main = OpFunction %void None %3 + %5 = OpLabel + %len = OpCooperativeMatrixLengthKHR %uint %coopmat_a + %p = OpAccessChain %ptr_ssbo_uint %ssbo %uint_0 %uint_0 + OpStore %p %len + OpReturn + OpFunctionEnd diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp new file mode 100644 index 000000000..a1fcb26e9 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp @@ -0,0 +1,51 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 50 +; Schema: 0 + OpCapability Shader + OpCapability CooperativeMatrixKHR + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + OpName %SSBO "SSBO" + OpMemberName %SSBO 0 "data" + OpName %ssbo "ssbo" + OpDecorate %arr_float ArrayStride 4 + OpMemberDecorate %SSBO 0 Offset 0 + OpDecorate %SSBO Block + OpDecorate %ssbo DescriptorSet 0 + OpDecorate %ssbo Binding 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %arr_float = OpTypeRuntimeArray %float + %SSBO = OpTypeStruct %arr_float +%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO + %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer +%ptr_ssbo_float = OpTypePointer StorageBuffer %float +%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 +%coopmat_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2 + %main = OpFunction %void None %3 + %5 = OpLabel +; Row-major load from offset 0 + %p0 = OpAccessChain %ptr_ssbo_float %ssbo %uint_0 %uint_0 + %mat_a = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8 +; Row-major store to offset 0 + OpCooperativeMatrixStoreKHR %p0 %mat_a %uint_0 %uint_8 +; Column-major load from offset 0 + %mat_col = OpCooperativeMatrixLoadKHR %coopmat_acc %p0 %uint_1 %uint_8 +; Column-major store to offset 0 + OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8 + OpReturn + OpFunctionEnd diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp new file mode 100644 index 000000000..1ddc3eb44 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 60 +; Schema: 0 + OpCapability Shader + OpCapability Float16 + OpCapability CooperativeMatrixKHR + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + OpName %SSBO32 "SSBO32" + OpMemberName %SSBO32 0 "data" + OpName %ssbo32 "ssbo32" + OpName %SSBO16 "SSBO16" + OpMemberName %SSBO16 0 "data" + OpName %ssbo16 "ssbo16" + OpDecorate %arr_float ArrayStride 4 + OpMemberDecorate %SSBO32 0 Offset 0 + OpDecorate %SSBO32 Block + OpDecorate %ssbo32 DescriptorSet 0 + OpDecorate %ssbo32 Binding 0 + OpDecorate %arr_half ArrayStride 2 + OpMemberDecorate %SSBO16 0 Offset 0 + OpDecorate %SSBO16 Block + OpDecorate %ssbo16 DescriptorSet 0 + OpDecorate %ssbo16 Binding 1 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %half = OpTypeFloat 16 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %arr_float = OpTypeRuntimeArray %float + %SSBO32 = OpTypeStruct %arr_float +%ptr_ssbo_SSBO32 = OpTypePointer StorageBuffer %SSBO32 + %ssbo32 = OpVariable %ptr_ssbo_SSBO32 StorageBuffer + %arr_half = OpTypeRuntimeArray %half + %SSBO16 = OpTypeStruct %arr_half +%ptr_ssbo_SSBO16 = OpTypePointer StorageBuffer %SSBO16 + %ssbo16 = OpVariable %ptr_ssbo_SSBO16 StorageBuffer +%ptr_ssbo_float = OpTypePointer StorageBuffer %float +%ptr_ssbo_half = OpTypePointer StorageBuffer %half +; float32 cooperative matrix types +%coopmat_f32_A = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 +%coopmat_f32_B = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_1 +%coopmat_f32_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2 +; half cooperative matrix types +%coopmat_f16_A = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_0 +%coopmat_f16_B = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_1 +%coopmat_f16_acc = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_2 + %main = OpFunction %void None %3 + %5 = OpLabel +; float32 muladd: D = A * B + C + %p_f32 = OpAccessChain %ptr_ssbo_float %ssbo32 %uint_0 %uint_0 + %f_A = OpCooperativeMatrixLoadKHR %coopmat_f32_A %p_f32 %uint_0 %uint_8 + %f_B = OpCooperativeMatrixLoadKHR %coopmat_f32_B %p_f32 %uint_0 %uint_8 + %f_C = OpCooperativeMatrixLoadKHR %coopmat_f32_acc %p_f32 %uint_0 %uint_8 + %f_D = OpCooperativeMatrixMulAddKHR %coopmat_f32_acc %f_A %f_B %f_C + OpCooperativeMatrixStoreKHR %p_f32 %f_D %uint_0 %uint_8 +; half muladd: D = A * B + C + %p_f16 = OpAccessChain %ptr_ssbo_half %ssbo16 %uint_0 %uint_0 + %h_A = OpCooperativeMatrixLoadKHR %coopmat_f16_A %p_f16 %uint_0 %uint_8 + %h_B = OpCooperativeMatrixLoadKHR %coopmat_f16_B %p_f16 %uint_0 %uint_8 + %h_C = OpCooperativeMatrixLoadKHR %coopmat_f16_acc %p_f16 %uint_0 %uint_8 + %h_D = OpCooperativeMatrixMulAddKHR %coopmat_f16_acc %h_A %h_B %h_C + OpCooperativeMatrixStoreKHR %p_f16 %h_D %uint_0 %uint_8 + OpReturn + OpFunctionEnd diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 025427cc8..26f3b6db3 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -1930,6 +1930,13 @@ void CompilerMSL::preprocess_op_codes() add_header_line("using namespace metal::raytracing;"); add_header_line("#endif"); } + + if (preproc.uses_cooperative_matrix) + { + if (!msl_options.supports_msl_version(3, 1)) + SPIRV_CROSS_THROW("Cooperative matrices require MSL 3.1 or later."); + add_header_line("#include "); + } } // Move the Private and Workgroup global variables to the entry function. @@ -10694,10 +10701,185 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) break; } + case OpCooperativeMatrixLoadKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t ptr = ops[2]; + uint32_t layout = ops[3]; + + auto &layout_c = get(layout); + if (layout_c.specialization) + SPIRV_CROSS_THROW("MSL cooperative matrix load does not support spec-constant layout."); + uint32_t layout_val = layout_c.scalar(); + bool col_major = false; + + switch (layout_val) + { + case CooperativeMatrixLayoutRowMajorKHR: + case CooperativeMatrixLayoutColumnMajorKHR: + if (instruction.length < 5) + SPIRV_CROSS_THROW("MSL cooperative matrix load requires Stride for row/column-major layouts."); + col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); + break; + + default: + SPIRV_CROSS_THROW("MSL cooperative matrix load only supports RowMajorKHR and ColumnMajorKHR layouts."); + } + + uint32_t stride = ops[4]; + + emit_uninitialized_temporary_expression(result_type, id); + auto ptr_expr = to_ptr_expression(ptr); + + if (col_major) + statement("simdgroup_load(", to_expression(id), ", ", + ptr_expr, ", ", to_expression(stride), ", ulong2(0), true);"); + else + statement("simdgroup_load(", to_expression(id), ", ", + ptr_expr, ", ", to_expression(stride), ");"); + + register_read(id, ptr, false); + break; + } + + case OpCooperativeMatrixStoreKHR: + { + uint32_t ptr = ops[0]; + uint32_t obj = ops[1]; + uint32_t layout = ops[2]; + + auto &layout_c = get(layout); + if (layout_c.specialization) + SPIRV_CROSS_THROW("MSL cooperative matrix store does not support spec-constant layout."); + uint32_t layout_val = layout_c.scalar(); + bool col_major = false; + + switch (layout_val) + { + case CooperativeMatrixLayoutRowMajorKHR: + case CooperativeMatrixLayoutColumnMajorKHR: + if (instruction.length < 4) + SPIRV_CROSS_THROW("MSL cooperative matrix store requires Stride for row/column-major layouts."); + col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); + break; + + default: + SPIRV_CROSS_THROW("MSL cooperative matrix store only supports RowMajorKHR and ColumnMajorKHR layouts."); + } + + uint32_t stride = ops[3]; + + auto ptr_expr = to_ptr_expression(ptr); + + if (col_major) + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", to_expression(stride), ", ulong2(0), true);"); + else + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", to_expression(stride), ");"); + + register_write(ptr); + break; + } + + case OpCooperativeMatrixMulAddKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + uint32_t A = ops[2], B = ops[3], C = ops[4]; + + emit_uninitialized_temporary_expression(result_type, id); + statement("simdgroup_multiply_accumulate(", to_expression(id), ", ", + to_unpacked_expression(A), ", ", + to_unpacked_expression(B), ", ", + to_unpacked_expression(C), ");"); + + inherit_expression_dependencies(id, A); + inherit_expression_dependencies(id, B); + inherit_expression_dependencies(id, C); + break; + } + + case OpCooperativeMatrixLengthKHR: + { + uint32_t result_type = ops[0]; + uint32_t id = ops[1]; + auto &coop_type = get(ops[2]); + + if (coop_type.op != OpTypeCooperativeMatrixKHR) + SPIRV_CROSS_THROW("OpCooperativeMatrixLengthKHR requires cooperative matrix type."); + + auto &component_type = get(coop_type.parent_type); + auto coop_type_name = type_to_glsl(coop_type); + auto component_type_name = type_to_glsl(component_type); + + auto expr = join(type_to_glsl(get(result_type)), + "(sizeof(", coop_type_name, "::storage_type) / sizeof(", component_type_name, "))"); + emit_op(result_type, id, expr, true); + break; + } + default: + { + auto is_cooperative_matrix_typed_id = [&](uint32_t typed_id) -> bool { + auto *type = maybe_get(typed_id); + if (!type) + { + auto *var = maybe_get(typed_id); + if (var) + type = maybe_get(var->basetype); + + auto *expr = maybe_get(typed_id); + if (!type && expr) + type = maybe_get(expr->expression_type); + + auto *constant = maybe_get(typed_id); + if (!type && constant) + type = maybe_get(constant->constant_type); + + auto *constant_op = maybe_get(typed_id); + if (!type && constant_op) + type = maybe_get(constant_op->basetype); + + auto *undef = maybe_get(typed_id); + if (!type && undef) + type = maybe_get(undef->basetype); + + auto *ac = maybe_get(typed_id); + if (!type && ac) + type = maybe_get(ac->basetype); + } + + while (type && (is_pointer(*type) || is_array(*type))) + type = maybe_get(type->parent_type); + + return type && type->op == OpTypeCooperativeMatrixKHR; + }; + + // Prevent GLSL cooperative matrix code from leaking into MSL output. + // Element-wise arithmetic on cooperative matrices is not supported in Metal. + if (instruction.length >= 2) + { + if (is_cooperative_matrix_typed_id(ops[0])) + SPIRV_CROSS_THROW("Unsupported operation on cooperative matrix in MSL backend."); + + if (opcode == OpCompositeExtract || opcode == OpVectorExtractDynamic) + { + if (instruction.length >= 3 && is_cooperative_matrix_typed_id(ops[2])) + SPIRV_CROSS_THROW("Unsupported extraction from cooperative matrix in MSL backend."); + } + else if (opcode == OpCompositeInsert || opcode == OpVectorInsertDynamic) + { + if ((instruction.length >= 3 && is_cooperative_matrix_typed_id(ops[2])) || + (instruction.length >= 4 && is_cooperative_matrix_typed_id(ops[3]))) + SPIRV_CROSS_THROW("Unsupported operation on cooperative matrix in MSL backend."); + } + } CompilerGLSL::emit_instruction(instruction); break; } + } previous_instruction_opcode = opcode; } @@ -16725,6 +16907,48 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member) return type_name; } + // Cooperative matrix -> Metal simdgroup matrix type + { + const SPIRType *coop_type = &type; + while (is_pointer(*coop_type) || is_array(*coop_type)) + coop_type = &get(coop_type->parent_type); + + if (coop_type->op == OpTypeCooperativeMatrixKHR) + { + if (!msl_options.supports_msl_version(3, 1)) + SPIRV_CROSS_THROW("Cooperative matrices require MSL 3.1 or later."); + + // Only Subgroup scope + auto &scope_c = get(coop_type->ext.cooperative.scope_id); + if (scope_c.specialization) + SPIRV_CROSS_THROW("MSL does not support spec-constant scope for cooperative matrices."); + if (scope_c.scalar() != ScopeSubgroup) + SPIRV_CROSS_THROW("MSL cooperative matrices only support Subgroup scope."); + + // Only 8x8 + auto &rows_c = get(coop_type->ext.cooperative.rows_id); + auto &cols_c = get(coop_type->ext.cooperative.columns_id); + if (rows_c.specialization || cols_c.specialization) + SPIRV_CROSS_THROW("MSL does not support spec-constant dimensions for cooperative matrices."); + if (rows_c.scalar() != 8 || cols_c.scalar() != 8) + SPIRV_CROSS_THROW("MSL cooperative matrices only support 8x8 dimensions."); + + // Map component type to simdgroup_*8x8 + auto &comp = get(coop_type->parent_type); + switch (comp.basetype) + { + case SPIRType::Float: + return "simdgroup_float8x8"; + case SPIRType::Half: + return "simdgroup_half8x8"; + case SPIRType::BFloat16: + return "simdgroup_bfloat8x8"; + default: + SPIRV_CROSS_THROW("Unsupported component type for MSL cooperative matrix."); + } + } + } + switch (type.basetype) { case SPIRType::Struct: @@ -16808,6 +17032,11 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member) case SPIRType::Double: type_name = "double"; // Currently unsupported break; + case SPIRType::BFloat16: + if (!msl_options.supports_msl_version(3, 1)) + SPIRV_CROSS_THROW("bfloat16 requires MSL 3.1 or later."); + type_name = "bfloat"; + break; case SPIRType::AccelerationStructure: if (msl_options.supports_msl_version(2, 4)) type_name = "raytracing::acceleration_structure"; @@ -18860,6 +19089,17 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui needs_helper_invocation = true; break; + case OpCooperativeMatrixLoadKHR: + case OpCooperativeMatrixMulAddKHR: + case OpCooperativeMatrixLengthKHR: + uses_cooperative_matrix = true; + break; + + case OpCooperativeMatrixStoreKHR: + uses_cooperative_matrix = true; + check_resource_write(args[0]); + break; + default: break; } diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 033cb903b..75b18f908 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -1401,6 +1401,7 @@ class CompilerMSL : public CompilerGLSL bool needs_subgroup_size = false; bool needs_sample_id = false; bool needs_helper_invocation = false; + bool uses_cooperative_matrix = false; }; // OpcodeHandler that scans for uses of sampled images From 6fc910e9b4ad2e528d00f3e4e2e1faf9c0e1fe35 Mon Sep 17 00:00:00 2001 From: Chengye YU Date: Tue, 10 Feb 2026 22:17:03 +0800 Subject: [PATCH 2/8] MSL: reject cooperative matrix muladd operand flags --- spirv_msl.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 26f3b6db3..740e5f7da 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10788,6 +10788,10 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) uint32_t result_type = ops[0]; uint32_t id = ops[1]; uint32_t A = ops[2], B = ops[3], C = ops[4]; + uint32_t matrix_operands = instruction.length >= 6 ? ops[5] : uint32_t(CooperativeMatrixOperandsMaskNone); + + if (matrix_operands != uint32_t(CooperativeMatrixOperandsMaskNone)) + SPIRV_CROSS_THROW("MSL cooperative matrix muladd does not support setting matrix operands flags."); emit_uninitialized_temporary_expression(result_type, id); statement("simdgroup_multiply_accumulate(", to_expression(id), ", ", From 188e392198da8bae2c621944b7079be37589f212 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Thu, 26 Feb 2026 10:48:12 +0100 Subject: [PATCH 3/8] Workaround some MSVC shenanigans --- spirv_msl.cpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 740e5f7da..cd4fe8e39 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10830,33 +10830,33 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) auto *type = maybe_get(typed_id); if (!type) { - auto *var = maybe_get(typed_id); + auto *var = this->maybe_get(typed_id); if (var) - type = maybe_get(var->basetype); + type = this->maybe_get(var->basetype); - auto *expr = maybe_get(typed_id); + auto *expr = this->maybe_get(typed_id); if (!type && expr) - type = maybe_get(expr->expression_type); + type = this->maybe_get(expr->expression_type); - auto *constant = maybe_get(typed_id); + auto *constant = this->maybe_get(typed_id); if (!type && constant) - type = maybe_get(constant->constant_type); + type = this->maybe_get(constant->constant_type); - auto *constant_op = maybe_get(typed_id); + auto *constant_op = this->maybe_get(typed_id); if (!type && constant_op) - type = maybe_get(constant_op->basetype); + type = this->maybe_get(constant_op->basetype); - auto *undef = maybe_get(typed_id); + auto *undef = this->maybe_get(typed_id); if (!type && undef) - type = maybe_get(undef->basetype); + type = this->maybe_get(undef->basetype); - auto *ac = maybe_get(typed_id); + auto *ac = this->maybe_get(typed_id); if (!type && ac) - type = maybe_get(ac->basetype); + type = this->maybe_get(ac->basetype); } while (type && (is_pointer(*type) || is_array(*type))) - type = maybe_get(type->parent_type); + type = this->maybe_get(type->parent_type); return type && type->op == OpTypeCooperativeMatrixKHR; }; From cda74fe5817b77300852718d594c34190c22e1da Mon Sep 17 00:00:00 2001 From: Chengye YU Date: Sat, 28 Feb 2026 15:42:25 +0800 Subject: [PATCH 4/8] MSL: Fix ptr-cast prepass and coopmat typed load/store --- ...x-workgroup-cast-load-store.asm.msl31.comp | 56 ++++++ ...matrix-workgroup-load-store.asm.msl31.comp | 58 ++++++ ...x-workgroup-cast-load-store.asm.msl31.comp | 41 ++++ ...matrix-workgroup-load-store.asm.msl31.comp | 39 ++++ spirv_cross.cpp | 9 +- spirv_msl.cpp | 179 ++++++++++++++---- 6 files changed, 344 insertions(+), 38 deletions(-) create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp create mode 100644 reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp create mode 100644 shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp new file mode 100644 index 000000000..5d29f8c25 --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp @@ -0,0 +1,56 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +kernel void main0() +{ + threadgroup spvUnsafeArray _15; + _15[0u] = uchar(0); + simdgroup_half8x8 _20; + simdgroup_load(_20, reinterpret_cast(&_15[0u]), (16u) / 2u); + simdgroup_store(_20, reinterpret_cast(&_15[0u]), (16u) / 2u); +} + diff --git a/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp new file mode 100644 index 000000000..a69c7b7fa --- /dev/null +++ b/reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp @@ -0,0 +1,58 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +kernel void main0() +{ + threadgroup spvUnsafeArray _14; + simdgroup_float8x8 _18; + simdgroup_load(_18, &_14[0u], 8u); + simdgroup_store(_18, &_14[0u], 8u); + simdgroup_float8x8 _19; + simdgroup_load(_19, &_14[0u], 8u, ulong2(0), true); + simdgroup_store(_19, &_14[0u], 8u, ulong2(0), true); +} + diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp new file mode 100644 index 000000000..a8022da52 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp @@ -0,0 +1,41 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 80 +; Schema: 0 + OpCapability Shader + OpCapability CooperativeMatrixKHR + OpCapability Float16 + OpCapability Int8 + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 + %u8 = OpTypeInt 8 0 + %half = OpTypeFloat 16 + %uint_0 = OpConstant %uint 0 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %uint_16 = OpConstant %uint 16 + %uint_128 = OpConstant %uint 128 + %arr_u8 = OpTypeArray %u8 %uint_128 +%ptr_wg_arr_u8 = OpTypePointer Workgroup %arr_u8 + %wg_data = OpVariable %ptr_wg_arr_u8 Workgroup +%ptr_wg_u8 = OpTypePointer Workgroup %u8 +%coopmat_half = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_0 + %u8_zero = OpConstant %u8 0 + %main = OpFunction %void None %3 + %5 = OpLabel +; Use uint8_t backing storage. MSL backend must cast pointer and convert stride from bytes to elements. + %p_u8 = OpAccessChain %ptr_wg_u8 %wg_data %uint_0 + OpStore %p_u8 %u8_zero + %mat = OpCooperativeMatrixLoadKHR %coopmat_half %p_u8 %uint_0 %uint_16 + OpCooperativeMatrixStoreKHR %p_u8 %mat %uint_0 %uint_16 + OpReturn + OpFunctionEnd diff --git a/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp b/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp new file mode 100644 index 000000000..8c49080c6 --- /dev/null +++ b/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp @@ -0,0 +1,39 @@ +; SPIR-V +; Version: 1.6 +; Generator: Khronos SPIR-V Tools Assembler; 0 +; Bound: 60 +; Schema: 0 + OpCapability Shader + OpCapability CooperativeMatrixKHR + OpCapability VulkanMemoryModel + OpExtension "SPV_KHR_cooperative_matrix" + OpExtension "SPV_KHR_vulkan_memory_model" + OpMemoryModel Logical Vulkan + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 32 1 1 + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %uint = OpTypeInt 32 0 + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 + %uint_3 = OpConstant %uint 3 + %uint_8 = OpConstant %uint 8 + %uint_64 = OpConstant %uint 64 + %arr_float = OpTypeArray %float %uint_64 +%ptr_wg_arr_float = OpTypePointer Workgroup %arr_float + %wg_data = OpVariable %ptr_wg_arr_float Workgroup +%ptr_wg_float = OpTypePointer Workgroup %float +%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 + %main = OpFunction %void None %3 + %5 = OpLabel +; Row-major load/store from workgroup memory. + %p0 = OpAccessChain %ptr_wg_float %wg_data %uint_0 + %mat_row = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8 + OpCooperativeMatrixStoreKHR %p0 %mat_row %uint_0 %uint_8 +; Column-major load/store from workgroup memory. + %mat_col = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_1 %uint_8 + OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8 + OpReturn + OpFunctionEnd diff --git a/spirv_cross.cpp b/spirv_cross.cpp index 1031d3ff6..ee4c2afc1 100644 --- a/spirv_cross.cpp +++ b/spirv_cross.cpp @@ -411,11 +411,19 @@ SPIRVariable *Compiler::maybe_get_backing_variable(uint32_t chain) { auto *cexpr = maybe_get(chain); if (cexpr) + { var = maybe_get(cexpr->loaded_from); + if (!var && cexpr->loaded_from != chain) + var = maybe_get_backing_variable(cexpr->loaded_from); + } auto *access_chain = maybe_get(chain); if (access_chain) + { var = maybe_get(access_chain->loaded_from); + if (!var && access_chain->loaded_from != chain) + var = maybe_get_backing_variable(access_chain->loaded_from); + } } return var; @@ -5804,4 +5812,3 @@ const SPIRType *Compiler::OpcodeHandler::get_expression_result_type(uint32_t id) return &compiler.get(itr->second); } - diff --git a/spirv_msl.cpp b/spirv_msl.cpp index cd4fe8e39..9dbb569e6 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10529,15 +10529,23 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) if (opcode != OpBitcast || is_pointer(type) || is_pointer(input_type)) { string op; + auto input_expr = is_pointer(input_type) ? to_ptr_expression(ops[2]) : to_unpacked_expression(ops[2]); if ((type.vecsize == 1 || is_pointer(type)) && (input_type.vecsize == 1 || is_pointer(input_type))) - op = join("reinterpret_cast<", type_to_glsl(type), ">(", to_unpacked_expression(ops[2]), ")"); + op = join("reinterpret_cast<", type_to_glsl(type), ">(", input_expr, ")"); else if (input_type.vecsize == 2) - op = join("reinterpret_cast<", type_to_glsl(type), ">(as_type(", to_unpacked_expression(ops[2]), "))"); + op = join("reinterpret_cast<", type_to_glsl(type), ">(as_type(", input_expr, "))"); else - op = join("as_type<", type_to_glsl(type), ">(reinterpret_cast(", to_unpacked_expression(ops[2]), "))"); + op = join("as_type<", type_to_glsl(type), ">(reinterpret_cast(", input_expr, "))"); - emit_op(ops[0], ops[1], op, should_forward(ops[2])); + auto &expr = emit_op(ops[0], ops[1], op, should_forward(ops[2])); + if (is_pointer(type)) + { + if (auto *backing_var = maybe_get_backing_variable(ops[2])) + expr.loaded_from = backing_var->self; + else + expr.loaded_from = ID(ops[2]); + } inherit_expression_dependencies(ops[1], ops[2]); } else @@ -10706,38 +10714,76 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) uint32_t result_type = ops[0]; uint32_t id = ops[1]; uint32_t ptr = ops[2]; - uint32_t layout = ops[3]; + uint32_t layout = ops[3]; - auto &layout_c = get(layout); - if (layout_c.specialization) - SPIRV_CROSS_THROW("MSL cooperative matrix load does not support spec-constant layout."); - uint32_t layout_val = layout_c.scalar(); - bool col_major = false; + auto &layout_c = get(layout); + if (layout_c.specialization) + SPIRV_CROSS_THROW("MSL cooperative matrix load does not support spec-constant layout."); + uint32_t layout_val = layout_c.scalar(); + bool col_major = false; - switch (layout_val) - { - case CooperativeMatrixLayoutRowMajorKHR: - case CooperativeMatrixLayoutColumnMajorKHR: - if (instruction.length < 5) - SPIRV_CROSS_THROW("MSL cooperative matrix load requires Stride for row/column-major layouts."); - col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); - break; + switch (layout_val) + { + case CooperativeMatrixLayoutRowMajorKHR: + case CooperativeMatrixLayoutColumnMajorKHR: + if (instruction.length < 5) + SPIRV_CROSS_THROW("MSL cooperative matrix load requires Stride for row/column-major layouts."); + col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); + break; - default: - SPIRV_CROSS_THROW("MSL cooperative matrix load only supports RowMajorKHR and ColumnMajorKHR layouts."); - } + default: + SPIRV_CROSS_THROW("MSL cooperative matrix load only supports RowMajorKHR and ColumnMajorKHR layouts."); + } uint32_t stride = ops[4]; emit_uninitialized_temporary_expression(result_type, id); + auto ptr_expr = to_ptr_expression(ptr); + string stride_expr = to_expression(stride); + + // The pointer operand is allowed to use a different element type than the cooperative matrix component type. + // In that case, cast the pointer and convert the stride from source element units to component element units. + auto &mat_type = get(result_type); + auto &component_type = get(mat_type.parent_type); + auto &ptr_type = expression_type(ptr); + auto &pointee_type = get(ptr_type.parent_type); + if (pointee_type.self != component_type.self) + { + auto addr_space = get_type_address_space(ptr_type, ptr); + ptr_expr = join("reinterpret_cast<", addr_space, " ", type_to_glsl(component_type), "*>(", ptr_expr, ")"); + + uint32_t src_bytes = (pointee_type.width * pointee_type.vecsize) / 8; + uint32_t dst_bytes = (component_type.width * component_type.vecsize) / 8; + if (src_bytes == 0 || dst_bytes == 0) + SPIRV_CROSS_THROW("Cannot determine element size for cooperative matrix load/store."); + + if (src_bytes == dst_bytes) + { + // No conversion needed. + } + else if (src_bytes > dst_bytes && (src_bytes % dst_bytes) == 0) + { + uint32_t multiplier = src_bytes / dst_bytes; + stride_expr = join("(", stride_expr, ") * ", multiplier, "u"); + } + else if (src_bytes < dst_bytes && (dst_bytes % src_bytes) == 0) + { + uint32_t divisor = dst_bytes / src_bytes; + stride_expr = join("(", stride_expr, ") / ", divisor, "u"); + } + else + { + stride_expr = join("((", stride_expr, ") * ", src_bytes, "u) / ", dst_bytes, "u"); + } + } if (col_major) statement("simdgroup_load(", to_expression(id), ", ", - ptr_expr, ", ", to_expression(stride), ", ulong2(0), true);"); + ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); else statement("simdgroup_load(", to_expression(id), ", ", - ptr_expr, ", ", to_expression(stride), ");"); + ptr_expr, ", ", stride_expr, ");"); register_read(id, ptr, false); break; @@ -10768,20 +10814,57 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) SPIRV_CROSS_THROW("MSL cooperative matrix store only supports RowMajorKHR and ColumnMajorKHR layouts."); } - uint32_t stride = ops[3]; + uint32_t stride = ops[3]; - auto ptr_expr = to_ptr_expression(ptr); + auto ptr_expr = to_ptr_expression(ptr); + string stride_expr = to_expression(stride); - if (col_major) - statement("simdgroup_store(", to_expression(obj), ", ", - ptr_expr, ", ", to_expression(stride), ", ulong2(0), true);"); - else - statement("simdgroup_store(", to_expression(obj), ", ", - ptr_expr, ", ", to_expression(stride), ");"); + // The pointer operand is allowed to use a different element type than the cooperative matrix component type. + // In that case, cast the pointer and convert the stride from source element units to component element units. + auto &mat_type = expression_type(obj); + auto &component_type = get(mat_type.parent_type); + auto &ptr_type = expression_type(ptr); + auto &pointee_type = get(ptr_type.parent_type); + if (pointee_type.self != component_type.self) + { + auto addr_space = get_type_address_space(ptr_type, ptr); + ptr_expr = join("reinterpret_cast<", addr_space, " ", type_to_glsl(component_type), "*>(", ptr_expr, ")"); - register_write(ptr); - break; - } + uint32_t src_bytes = (pointee_type.width * pointee_type.vecsize) / 8; + uint32_t dst_bytes = (component_type.width * component_type.vecsize) / 8; + if (src_bytes == 0 || dst_bytes == 0) + SPIRV_CROSS_THROW("Cannot determine element size for cooperative matrix load/store."); + + if (src_bytes == dst_bytes) + { + // No conversion needed. + } + else if (src_bytes > dst_bytes && (src_bytes % dst_bytes) == 0) + { + uint32_t multiplier = src_bytes / dst_bytes; + stride_expr = join("(", stride_expr, ") * ", multiplier, "u"); + } + else if (src_bytes < dst_bytes && (dst_bytes % src_bytes) == 0) + { + uint32_t divisor = dst_bytes / src_bytes; + stride_expr = join("(", stride_expr, ") / ", divisor, "u"); + } + else + { + stride_expr = join("((", stride_expr, ") * ", src_bytes, "u) / ", dst_bytes, "u"); + } + } + + if (col_major) + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); + else + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", stride_expr, ");"); + + register_write(ptr); + break; + } case OpCooperativeMatrixMulAddKHR: { @@ -16914,10 +16997,10 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member) // Cooperative matrix -> Metal simdgroup matrix type { const SPIRType *coop_type = &type; - while (is_pointer(*coop_type) || is_array(*coop_type)) - coop_type = &get(coop_type->parent_type); + while (coop_type && (is_pointer(*coop_type) || is_array(*coop_type))) + coop_type = maybe_get(coop_type->parent_type); - if (coop_type->op == OpTypeCooperativeMatrixKHR) + if (coop_type && coop_type->op == OpTypeCooperativeMatrixKHR) { if (!msl_options.supports_msl_version(3, 1)) SPIRV_CROSS_THROW("Cooperative matrices require MSL 3.1 or later."); @@ -19032,6 +19115,28 @@ bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, ui break; } + case OpBitcast: + case OpConvertPtrToU: + case OpConvertUToPtr: + { + if (length < 3) + break; + + auto &result_type = self.get(args[0]); + auto *arg_type = get_expression_result_type(args[2]); + if (!arg_type) + arg_type = &self.expression_type(args[2]); + + if (opcode != OpBitcast || self.is_pointer(result_type) || (arg_type && self.is_pointer(*arg_type))) + { + uint32_t id = args[1]; + set(id, "", args[0], true); + self.register_read(id, args[2], true); + self.ir.ids[id].set_allow_type_rewrite(); + } + break; + } + case OpExtInst: { uint32_t extension_set = args[2]; From 0e67b321b068d4973c9c2cd7ab96a17a3eff0bb1 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Fri, 13 Mar 2026 11:10:22 +0100 Subject: [PATCH 5/8] Update some stray references. --- ...ice_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp | 2 +- ...ice_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp b/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp index 30e89455d..3cd0c3256 100644 --- a/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp +++ b/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp @@ -19,7 +19,7 @@ struct SSBO kernel void main0(constant UBO& _10 [[buffer(0)]]) { (reinterpret_cast(as_type(_10.b)))->a1 = float3(1.0, 2.0, 3.0); - device SSBO* _39 = reinterpret_cast(as_type(as_type(reinterpret_cast(reinterpret_cast(as_type(_10.b + uint2(32u))))))); + device SSBO* _39 = reinterpret_cast(as_type(as_type(reinterpret_cast((reinterpret_cast(as_type(_10.b + uint2(32u)))))))); _39->a1 = float3(_39->a1) + float3(1.0); } diff --git a/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp b/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp index f79a8b520..ba8d81a20 100644 --- a/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp +++ b/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp @@ -19,7 +19,7 @@ struct SSBO kernel void main0(constant UBO& _10 [[buffer(0)]]) { (reinterpret_cast(as_type(_10.b)))->a1 = float3(1.0, 2.0, 3.0); - uint2 v2 = as_type(reinterpret_cast(reinterpret_cast(as_type(_10.b + uint2(32u))))); + uint2 v2 = as_type(reinterpret_cast((reinterpret_cast(as_type(_10.b + uint2(32u)))))); float3 v3 = float3((reinterpret_cast(as_type(v2)))->a1); (reinterpret_cast(as_type(v2)))->a1 = v3 + float3(1.0); } From 7c69662b97f07db4ae5e7c84ad0601bd80b2e61b Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Fri, 13 Mar 2026 11:10:39 +0100 Subject: [PATCH 6/8] Indentation fixes. --- spirv_msl.cpp | 126 +++++++++++++++++++++++++------------------------- 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 9dbb569e6..3cbd5dd5c 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10714,26 +10714,26 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) uint32_t result_type = ops[0]; uint32_t id = ops[1]; uint32_t ptr = ops[2]; - uint32_t layout = ops[3]; + uint32_t layout = ops[3]; - auto &layout_c = get(layout); - if (layout_c.specialization) - SPIRV_CROSS_THROW("MSL cooperative matrix load does not support spec-constant layout."); - uint32_t layout_val = layout_c.scalar(); - bool col_major = false; + auto &layout_c = get(layout); + if (layout_c.specialization) + SPIRV_CROSS_THROW("MSL cooperative matrix load does not support spec-constant layout."); + uint32_t layout_val = layout_c.scalar(); + bool col_major = false; - switch (layout_val) - { - case CooperativeMatrixLayoutRowMajorKHR: - case CooperativeMatrixLayoutColumnMajorKHR: - if (instruction.length < 5) - SPIRV_CROSS_THROW("MSL cooperative matrix load requires Stride for row/column-major layouts."); - col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); - break; + switch (layout_val) + { + case CooperativeMatrixLayoutRowMajorKHR: + case CooperativeMatrixLayoutColumnMajorKHR: + if (instruction.length < 5) + SPIRV_CROSS_THROW("MSL cooperative matrix load requires Stride for row/column-major layouts."); + col_major = (layout_val == CooperativeMatrixLayoutColumnMajorKHR); + break; - default: - SPIRV_CROSS_THROW("MSL cooperative matrix load only supports RowMajorKHR and ColumnMajorKHR layouts."); - } + default: + SPIRV_CROSS_THROW("MSL cooperative matrix load only supports RowMajorKHR and ColumnMajorKHR layouts."); + } uint32_t stride = ops[4]; @@ -10780,10 +10780,10 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) if (col_major) statement("simdgroup_load(", to_expression(id), ", ", - ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); + ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); else statement("simdgroup_load(", to_expression(id), ", ", - ptr_expr, ", ", stride_expr, ");"); + ptr_expr, ", ", stride_expr, ");"); register_read(id, ptr, false); break; @@ -10814,58 +10814,58 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) SPIRV_CROSS_THROW("MSL cooperative matrix store only supports RowMajorKHR and ColumnMajorKHR layouts."); } - uint32_t stride = ops[3]; + uint32_t stride = ops[3]; - auto ptr_expr = to_ptr_expression(ptr); - string stride_expr = to_expression(stride); + auto ptr_expr = to_ptr_expression(ptr); + string stride_expr = to_expression(stride); - // The pointer operand is allowed to use a different element type than the cooperative matrix component type. - // In that case, cast the pointer and convert the stride from source element units to component element units. - auto &mat_type = expression_type(obj); - auto &component_type = get(mat_type.parent_type); - auto &ptr_type = expression_type(ptr); - auto &pointee_type = get(ptr_type.parent_type); - if (pointee_type.self != component_type.self) - { - auto addr_space = get_type_address_space(ptr_type, ptr); - ptr_expr = join("reinterpret_cast<", addr_space, " ", type_to_glsl(component_type), "*>(", ptr_expr, ")"); + // The pointer operand is allowed to use a different element type than the cooperative matrix component type. + // In that case, cast the pointer and convert the stride from source element units to component element units. + auto &mat_type = expression_type(obj); + auto &component_type = get(mat_type.parent_type); + auto &ptr_type = expression_type(ptr); + auto &pointee_type = get(ptr_type.parent_type); + if (pointee_type.self != component_type.self) + { + auto addr_space = get_type_address_space(ptr_type, ptr); + ptr_expr = join("reinterpret_cast<", addr_space, " ", type_to_glsl(component_type), "*>(", ptr_expr, ")"); - uint32_t src_bytes = (pointee_type.width * pointee_type.vecsize) / 8; - uint32_t dst_bytes = (component_type.width * component_type.vecsize) / 8; - if (src_bytes == 0 || dst_bytes == 0) - SPIRV_CROSS_THROW("Cannot determine element size for cooperative matrix load/store."); + uint32_t src_bytes = (pointee_type.width * pointee_type.vecsize) / 8; + uint32_t dst_bytes = (component_type.width * component_type.vecsize) / 8; + if (src_bytes == 0 || dst_bytes == 0) + SPIRV_CROSS_THROW("Cannot determine element size for cooperative matrix load/store."); - if (src_bytes == dst_bytes) - { - // No conversion needed. - } - else if (src_bytes > dst_bytes && (src_bytes % dst_bytes) == 0) - { - uint32_t multiplier = src_bytes / dst_bytes; - stride_expr = join("(", stride_expr, ") * ", multiplier, "u"); - } - else if (src_bytes < dst_bytes && (dst_bytes % src_bytes) == 0) - { - uint32_t divisor = dst_bytes / src_bytes; - stride_expr = join("(", stride_expr, ") / ", divisor, "u"); - } - else - { - stride_expr = join("((", stride_expr, ") * ", src_bytes, "u) / ", dst_bytes, "u"); - } + if (src_bytes == dst_bytes) + { + // No conversion needed. + } + else if (src_bytes > dst_bytes && (src_bytes % dst_bytes) == 0) + { + uint32_t multiplier = src_bytes / dst_bytes; + stride_expr = join("(", stride_expr, ") * ", multiplier, "u"); + } + else if (src_bytes < dst_bytes && (dst_bytes % src_bytes) == 0) + { + uint32_t divisor = dst_bytes / src_bytes; + stride_expr = join("(", stride_expr, ") / ", divisor, "u"); } - - if (col_major) - statement("simdgroup_store(", to_expression(obj), ", ", - ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); else - statement("simdgroup_store(", to_expression(obj), ", ", - ptr_expr, ", ", stride_expr, ");"); - - register_write(ptr); - break; + { + stride_expr = join("((", stride_expr, ") * ", src_bytes, "u) / ", dst_bytes, "u"); + } } + if (col_major) + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", stride_expr, ", ulong2(0), true);"); + else + statement("simdgroup_store(", to_expression(obj), ", ", + ptr_expr, ", ", stride_expr, ");"); + + register_write(ptr); + break; + } + case OpCooperativeMatrixMulAddKHR: { uint32_t result_type = ops[0]; From 8f3792430d6c55ccb5f367222d2a3a2026f5fa8c Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Fri, 13 Mar 2026 11:10:47 +0100 Subject: [PATCH 7/8] Simplify unsupported coopmat check. --- spirv_msl.cpp | 54 ++++++++++++++++----------------------------------- 1 file changed, 17 insertions(+), 37 deletions(-) diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 3cbd5dd5c..dc70fc2ac 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10909,47 +10909,27 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) default: { - auto is_cooperative_matrix_typed_id = [&](uint32_t typed_id) -> bool { - auto *type = maybe_get(typed_id); - if (!type) - { - auto *var = this->maybe_get(typed_id); - if (var) - type = this->maybe_get(var->basetype); - - auto *expr = this->maybe_get(typed_id); - if (!type && expr) - type = this->maybe_get(expr->expression_type); - - auto *constant = this->maybe_get(typed_id); - if (!type && constant) - type = this->maybe_get(constant->constant_type); - - auto *constant_op = this->maybe_get(typed_id); - if (!type && constant_op) - type = this->maybe_get(constant_op->basetype); - - auto *undef = this->maybe_get(typed_id); - if (!type && undef) - type = this->maybe_get(undef->basetype); - - auto *ac = this->maybe_get(typed_id); - if (!type && ac) - type = this->maybe_get(ac->basetype); - } - - while (type && (is_pointer(*type) || is_array(*type))) - type = this->maybe_get(type->parent_type); - - return type && type->op == OpTypeCooperativeMatrixKHR; - }; - // Prevent GLSL cooperative matrix code from leaking into MSL output. // Element-wise arithmetic on cooperative matrices is not supported in Metal. + // Should cover any reasonable situation we come across. if (instruction.length >= 2) { - if (is_cooperative_matrix_typed_id(ops[0])) - SPIRV_CROSS_THROW("Unsupported operation on cooperative matrix in MSL backend."); + bool has_result = false, has_result_type = false; + HasResultAndType(opcode, &has_result, &has_result_type); + + if (has_result_type) + { + auto *type = &get(ops[0]); + while (type && (is_pointer(*type) || is_array(*type))) + type = this->maybe_get(type->parent_type); + if (type->op == OpTypeCooperativeMatrixKHR) + SPIRV_CROSS_THROW("Unsupported operation on cooperative matrix in MSL backend."); + } + + auto is_cooperative_matrix_typed_id = [&](uint32_t id) -> bool { + auto &type = expression_type(id); + return type.op == OpTypeCooperativeMatrixKHR; + }; if (opcode == OpCompositeExtract || opcode == OpVectorExtractDynamic) { From 7594d2b40ab98f6952c8a7529575520835d1be14 Mon Sep 17 00:00:00 2001 From: Hans-Kristian Arntzen Date: Fri, 13 Mar 2026 11:25:38 +0100 Subject: [PATCH 8/8] Revert questionable change to bitcast. --- ...ice_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp | 2 +- ...ice_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp | 2 +- spirv_msl.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp b/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp index 3cd0c3256..30e89455d 100644 --- a/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp +++ b/reference/opt/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp @@ -19,7 +19,7 @@ struct SSBO kernel void main0(constant UBO& _10 [[buffer(0)]]) { (reinterpret_cast(as_type(_10.b)))->a1 = float3(1.0, 2.0, 3.0); - device SSBO* _39 = reinterpret_cast(as_type(as_type(reinterpret_cast((reinterpret_cast(as_type(_10.b + uint2(32u)))))))); + device SSBO* _39 = reinterpret_cast(as_type(as_type(reinterpret_cast(reinterpret_cast(as_type(_10.b + uint2(32u))))))); _39->a1 = float3(_39->a1) + float3(1.0); } diff --git a/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp b/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp index ba8d81a20..f79a8b520 100644 --- a/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp +++ b/reference/shaders-msl/comp/buffer_device_address-packed-vec-and-cast-to-and-from-uvec2.msl23.comp @@ -19,7 +19,7 @@ struct SSBO kernel void main0(constant UBO& _10 [[buffer(0)]]) { (reinterpret_cast(as_type(_10.b)))->a1 = float3(1.0, 2.0, 3.0); - uint2 v2 = as_type(reinterpret_cast((reinterpret_cast(as_type(_10.b + uint2(32u)))))); + uint2 v2 = as_type(reinterpret_cast(reinterpret_cast(as_type(_10.b + uint2(32u))))); float3 v3 = float3((reinterpret_cast(as_type(v2)))->a1); (reinterpret_cast(as_type(v2)))->a1 = v3 + float3(1.0); } diff --git a/spirv_msl.cpp b/spirv_msl.cpp index dc70fc2ac..384549363 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -10529,7 +10529,7 @@ void CompilerMSL::emit_instruction(const Instruction &instruction) if (opcode != OpBitcast || is_pointer(type) || is_pointer(input_type)) { string op; - auto input_expr = is_pointer(input_type) ? to_ptr_expression(ops[2]) : to_unpacked_expression(ops[2]); + auto input_expr = to_unpacked_expression(ops[2]); if ((type.vecsize == 1 || is_pointer(type)) && (input_type.vecsize == 1 || is_pointer(input_type))) op = join("reinterpret_cast<", type_to_glsl(type), ">(", input_expr, ")");