diff --git a/reference/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp new file mode 100644 index 000000000..21555ab32 --- /dev/null +++ b/reference/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp @@ -0,0 +1,916 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +template +inline T spvSubgroupBroadcast(T value, ushort lane) +{ + return simd_broadcast(value, lane); +} + +template<> +inline bool spvSubgroupBroadcast(bool value, ushort lane) +{ + return !!simd_broadcast((ushort)value, lane); +} + +template +inline vec spvSubgroupBroadcast(vec value, ushort lane) +{ + return (vec)simd_broadcast((vec)value, lane); +} + +template +inline T spvSubgroupBroadcastFirst(T value) +{ + return simd_broadcast_first(value); +} + +template<> +inline bool spvSubgroupBroadcastFirst(bool value) +{ + return !!simd_broadcast_first((ushort)value); +} + +template +inline vec spvSubgroupBroadcastFirst(vec value) +{ + return (vec)simd_broadcast_first((vec)value); +} + +inline uint4 spvSubgroupBallot(bool value) +{ + simd_vote vote = simd_ballot(value); + // simd_ballot() returns a 64-bit integer-like object, but + // SPIR-V callers expect a uint4. We must convert. + // FIXME: This won't include higher bits if Apple ever supports + // 128 lanes in an SIMD-group. + return uint4(as_type((simd_vote::vote_t)vote), 0, 0); +} + +inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit) +{ + return !!extract_bits(ballot[bit / 32], bit % 32, 1); +} + +inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize) +{ + uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0)); + ballot &= mask; + return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0); +} + +inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize) +{ + uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0)); + ballot &= mask; + return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - (clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), ballot.z == 0), ballot.w == 0); +} + +inline uint spvPopCount4(uint4 ballot) +{ + return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w); +} + +inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize) +{ + uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0)); + return spvPopCount4(ballot & mask); +} + +inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID) +{ + uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0)); + return spvPopCount4(ballot & mask); +} + +inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID) +{ + uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0)); + return spvPopCount4(ballot & mask); +} + +template +inline bool spvSubgroupAllEqual(T value) +{ + return simd_all(all(value == simd_broadcast_first(value))); +} + +template<> +inline bool spvSubgroupAllEqual(bool value) +{ + return simd_all(value) || !simd_any(value); +} + +template +inline bool spvSubgroupAllEqual(vec value) +{ + return simd_all(all(value == (vec)simd_broadcast_first((vec)value))); +} + +template +inline T spvSubgroupShuffle(T value, ushort lane) +{ + return simd_shuffle(value, lane); +} + +template<> +inline bool spvSubgroupShuffle(bool value, ushort lane) +{ + return !!simd_shuffle((ushort)value, lane); +} + +template +inline vec spvSubgroupShuffle(vec value, ushort lane) +{ + return (vec)simd_shuffle((vec)value, lane); +} + +template<> +inline ulong spvSubgroupShuffle(ulong value, ushort lane) +{ + return as_type(spvSubgroupShuffle(as_type(value), lane)); +} + +template<> +inline ulong2 spvSubgroupShuffle(ulong2 value, ushort lane) +{ + return ulong2(spvSubgroupShuffle(value.x, lane), spvSubgroupShuffle(value.y, lane)); +} + +inline ulong3 spvSubgroupShuffle(ulong3 value, ushort lane) +{ + return ulong3(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.z, lane)); +} + +inline ulong4 spvSubgroupShuffle(ulong4 value, ushort lane) +{ + return ulong4(spvSubgroupShuffle(value.xy, lane), spvSubgroupShuffle(value.zw, lane)); +} + +template +inline vec spvSubgroupShuffle(vec value, ushort lane) +{ + return vec(spvSubgroupShuffle(vec(value), lane)); +} + +template +inline T spvSubgroupShuffleXor(T value, ushort mask) +{ + return simd_shuffle_xor(value, mask); +} + +template<> +inline bool spvSubgroupShuffleXor(bool value, ushort mask) +{ + return !!simd_shuffle_xor((ushort)value, mask); +} + +template +inline vec spvSubgroupShuffleXor(vec value, ushort mask) +{ + return (vec)simd_shuffle_xor((vec)value, mask); +} + +template +inline T spvSubgroupShuffleUp(T value, ushort delta) +{ + return simd_shuffle_up(value, delta); +} + +template<> +inline bool spvSubgroupShuffleUp(bool value, ushort delta) +{ + return !!simd_shuffle_up((ushort)value, delta); +} + +template +inline vec spvSubgroupShuffleUp(vec value, ushort delta) +{ + return (vec)simd_shuffle_up((vec)value, delta); +} + +template +inline T spvSubgroupShuffleDown(T value, ushort delta) +{ + return simd_shuffle_down(value, delta); +} + +template<> +inline bool spvSubgroupShuffleDown(bool value, ushort delta) +{ + return !!simd_shuffle_down((ushort)value, delta); +} + +template +inline vec spvSubgroupShuffleDown(vec value, ushort delta) +{ + return (vec)simd_shuffle_down((vec)value, delta); +} + +template +inline T spvSubgroupRotate(T value, ushort delta) +{ + return simd_shuffle_rotate_down(value, delta); +} + +template<> +inline bool spvSubgroupRotate(bool value, ushort delta) +{ + return !!simd_shuffle_rotate_down((ushort)value, delta); +} + +template +inline vec spvSubgroupRotate(vec value, ushort delta) +{ + return (vec)simd_shuffle_rotate_down((vec)value, delta); +} + +template +struct spvClusteredAddDetail; + +// Base cases +template<> +struct spvClusteredAddDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredAddDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return 0; + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredAddDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_sum(value); + } +}; + +template +struct spvClusteredAddDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return 0; + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_sum(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredAddDetail +{ + template + static T op(T value, uint lid) + { + return spvClusteredAddDetail::op(value, lid) + spvClusteredAddDetail::op(value, lid); + } +}; + +template +T spvClustered_sum(T value, uint lid) +{ + return spvClusteredAddDetail::op(value, lid); +} + +template +struct spvClusteredMulDetail; + +// Base cases +template<> +struct spvClusteredMulDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredMulDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return 1; + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredMulDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_product(value); + } +}; + +template +struct spvClusteredMulDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return 1; + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_product(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredMulDetail +{ + template + static T op(T value, uint lid) + { + return spvClusteredMulDetail::op(value, lid) * spvClusteredMulDetail::op(value, lid); + } +}; + +template +T spvClustered_product(T value, uint lid) +{ + return spvClusteredMulDetail::op(value, lid); +} + +template +struct spvClusteredMinDetail; + +// Base cases +template<> +struct spvClusteredMinDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredMinDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return numeric_limits::max(); + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredMinDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_min(value); + } +}; + +template +struct spvClusteredMinDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return numeric_limits::max(); + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_min(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredMinDetail +{ + template + static T op(T value, uint lid) + { + return min(spvClusteredMinDetail::op(value, lid), spvClusteredMinDetail::op(value, lid)); + } +}; + +template +T spvClustered_min(T value, uint lid) +{ + return spvClusteredMinDetail::op(value, lid); +} + +template +struct spvClusteredMaxDetail; + +// Base cases +template<> +struct spvClusteredMaxDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredMaxDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return numeric_limits::min(); + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredMaxDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_max(value); + } +}; + +template +struct spvClusteredMaxDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return numeric_limits::min(); + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_max(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredMaxDetail +{ + template + static T op(T value, uint lid) + { + return max(spvClusteredMaxDetail::op(value, lid), spvClusteredMaxDetail::op(value, lid)); + } +}; + +template +T spvClustered_max(T value, uint lid) +{ + return spvClusteredMaxDetail::op(value, lid); +} + +template +struct spvClusteredAndDetail; + +// Base cases +template<> +struct spvClusteredAndDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredAndDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return ~T(0); + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredAndDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_and(value); + } +}; + +template +struct spvClusteredAndDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return ~T(0); + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_and(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredAndDetail +{ + template + static T op(T value, uint lid) + { + return spvClusteredAndDetail::op(value, lid) & spvClusteredAndDetail::op(value, lid); + } +}; + +template +T spvClustered_and(T value, uint lid) +{ + return spvClusteredAndDetail::op(value, lid); +} + +template +struct spvClusteredOrDetail; + +// Base cases +template<> +struct spvClusteredOrDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredOrDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return 0; + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredOrDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_or(value); + } +}; + +template +struct spvClusteredOrDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return 0; + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_or(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredOrDetail +{ + template + static T op(T value, uint lid) + { + return spvClusteredOrDetail::op(value, lid) | spvClusteredOrDetail::op(value, lid); + } +}; + +template +T spvClustered_or(T value, uint lid) +{ + return spvClusteredOrDetail::op(value, lid); +} + +template +struct spvClusteredXorDetail; + +// Base cases +template<> +struct spvClusteredXorDetail<1, 0> +{ + template + static T op(T value, uint) + { + return value; + } +}; + +template +struct spvClusteredXorDetail<1, offset> +{ + template + static T op(T value, uint lid) + { + // If the target lane is inactive, then return identity. + if (!extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], (lid ^ offset) % 32, 1)) + return 0; + return simd_shuffle_xor(value, offset); + } +}; + +template<> +struct spvClusteredXorDetail<4, 0> +{ + template + static T op(T value, uint) + { + return quad_xor(value); + } +}; + +template +struct spvClusteredXorDetail<4, offset> +{ + template + static T op(T value, uint lid) + { + // Here, we care if any of the lanes in the quad are active. + uint quad_mask = extract_bits(as_type((simd_vote::vote_t)simd_active_threads_mask())[(lid ^ offset) / 32], ((lid ^ offset) % 32) & ~3, 4); + if (!quad_mask) + return 0; + // But we need to make sure we shuffle from an active lane. + return simd_shuffle(quad_xor(value), ((lid ^ offset) & ~3) | ctz(quad_mask)); + } +}; + +// General case +template +struct spvClusteredXorDetail +{ + template + static T op(T value, uint lid) + { + return spvClusteredXorDetail::op(value, lid) ^ spvClusteredXorDetail::op(value, lid); + } +}; + +template +T spvClustered_xor(T value, uint lid) +{ + return spvClusteredXorDetail::op(value, lid); +} + +template +inline T spvQuadBroadcast(T value, uint lane) +{ + return quad_broadcast(value, lane); +} + +template<> +inline bool spvQuadBroadcast(bool value, uint lane) +{ + return !!quad_broadcast((ushort)value, lane); +} + +template +inline vec spvQuadBroadcast(vec value, uint lane) +{ + return (vec)quad_broadcast((vec)value, lane); +} + +template +inline T spvQuadSwap(T value, uint dir) +{ + return quad_shuffle_xor(value, dir + 1); +} + +template<> +inline bool spvQuadSwap(bool value, uint dir) +{ + return !!quad_shuffle_xor((ushort)value, dir + 1); +} + +template +inline vec spvQuadSwap(vec value, uint dir) +{ + return (vec)quad_shuffle_xor((vec)value, dir + 1); +} + +struct SSBO +{ + float FragColor; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +static inline __attribute__((always_inline)) +void doClusteredRotate(thread uint& gl_SubgroupInvocationID) +{ + uint _15 = spvSubgroupShuffle(20u, ((gl_SubgroupInvocationID + 4u) & 7) + (gl_SubgroupInvocationID & 4294967288)); + uint rotated_clustered = _15; + bool _20 = spvSubgroupShuffle(false, ((gl_SubgroupInvocationID + 4u) & 7) + (gl_SubgroupInvocationID & 4294967288)); + bool rotated_clustered_bool = _20; +} + +kernel void main0(device SSBO& _24 [[buffer(0)]], uint gl_NumSubgroups [[simdgroups_per_threadgroup]], uint gl_SubgroupID [[simdgroup_index_in_threadgroup]], uint gl_SubgroupSize [[threads_per_simdgroup]], uint gl_SubgroupInvocationID [[thread_index_in_simdgroup]]) +{ + uint4 gl_SubgroupEqMask = gl_SubgroupInvocationID >= 32 ? uint4(0, (1 << (gl_SubgroupInvocationID - 32)), uint2(0)) : uint4(1 << gl_SubgroupInvocationID, uint3(0)); + uint4 gl_SubgroupGeMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID, 32u), 0)), uint2(0)); + uint4 gl_SubgroupGtMask = uint4(insert_bits(0u, 0xFFFFFFFF, min(gl_SubgroupInvocationID + 1, 32u), (uint)max(min((int)gl_SubgroupSize, 32) - (int)gl_SubgroupInvocationID - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0), (uint)max((int)gl_SubgroupSize - (int)max(gl_SubgroupInvocationID + 1, 32u), 0)), uint2(0)); + uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), uint2(0)); + uint4 gl_SubgroupLtMask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0)); + _24.FragColor = float(gl_NumSubgroups); + _24.FragColor = float(gl_SubgroupID); + _24.FragColor = float(gl_SubgroupSize); + _24.FragColor = float(gl_SubgroupInvocationID); + simdgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture); + atomic_thread_fence(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture, memory_order_seq_cst, thread_scope_simdgroup); + atomic_thread_fence(mem_flags::mem_device, memory_order_seq_cst, thread_scope_simdgroup); + atomic_thread_fence(mem_flags::mem_threadgroup, memory_order_seq_cst, thread_scope_simdgroup); + atomic_thread_fence(mem_flags::mem_texture, memory_order_seq_cst, thread_scope_simdgroup); + bool _50 = simd_is_first(); + bool elected = _50; + _24.FragColor = float4(gl_SubgroupEqMask).x; + _24.FragColor = float4(gl_SubgroupGeMask).x; + _24.FragColor = float4(gl_SubgroupGtMask).x; + _24.FragColor = float4(gl_SubgroupLeMask).x; + _24.FragColor = float4(gl_SubgroupLtMask).x; + float4 broadcasted = spvSubgroupBroadcast(float4(10.0), 8u); + bool2 broadcasted_bool = spvSubgroupBroadcast(bool2(true), 8u); + float3 first = spvSubgroupBroadcastFirst(float3(20.0)); + bool4 first_bool = spvSubgroupBroadcastFirst(bool4(false)); + uint4 ballot_value = spvSubgroupBallot(true); + bool inverse_ballot_value = spvSubgroupBallotBitExtract(ballot_value, gl_SubgroupInvocationID); + bool bit_extracted = spvSubgroupBallotBitExtract(uint4(10u), 8u); + uint bit_count = spvSubgroupBallotBitCount(ballot_value, gl_SubgroupSize); + uint inclusive_bit_count = spvSubgroupBallotInclusiveBitCount(ballot_value, gl_SubgroupInvocationID); + uint exclusive_bit_count = spvSubgroupBallotExclusiveBitCount(ballot_value, gl_SubgroupInvocationID); + uint lsb = spvSubgroupBallotFindLSB(ballot_value, gl_SubgroupSize); + uint msb = spvSubgroupBallotFindMSB(ballot_value, gl_SubgroupSize); + uint shuffled = spvSubgroupShuffle(10u, 8u); + bool shuffled_bool = spvSubgroupShuffle(true, 9u); + uint shuffled_xor = spvSubgroupShuffleXor(30u, 8u); + bool shuffled_xor_bool = spvSubgroupShuffleXor(false, 9u); + uint shuffled_up = spvSubgroupShuffleUp(20u, 4u); + bool shuffled_up_bool = spvSubgroupShuffleUp(true, 4u); + uint shuffled_down = spvSubgroupShuffleDown(20u, 4u); + bool shuffled_down_bool = spvSubgroupShuffleDown(false, 4u); + uint rotated = spvSubgroupRotate(20u, 4u); + bool rotated_bool = spvSubgroupRotate(false, 4u); + doClusteredRotate(gl_SubgroupInvocationID); + bool has_all = simd_all(true); + bool has_any = simd_any(true); + bool has_equal = spvSubgroupAllEqual(0); + has_equal = spvSubgroupAllEqual(true); + has_equal = spvSubgroupAllEqual(float3(0.0, 1.0, 2.0)); + has_equal = spvSubgroupAllEqual(bool4(true, true, false, true)); + float4 added = simd_sum(float4(20.0)); + int4 iadded = simd_sum(int4(20)); + float4 multiplied = simd_product(float4(20.0)); + int4 imultiplied = simd_product(int4(20)); + float4 lo = simd_min(float4(20.0)); + float4 hi = simd_max(float4(20.0)); + int4 slo = simd_min(int4(20)); + int4 shi = simd_max(int4(20)); + uint4 ulo = simd_min(uint4(20u)); + uint4 uhi = simd_max(uint4(20u)); + uint4 anded = simd_and(ballot_value); + uint4 ored = simd_or(ballot_value); + uint4 xored = simd_xor(ballot_value); + bool4 anded_b = bool4(simd_and(ushort4(ballot_value == uint4(42u)))); + bool4 ored_b = bool4(simd_or(ushort4(ballot_value == uint4(42u)))); + bool4 xored_b = bool4(simd_xor(ushort4(ballot_value == uint4(42u)))); + added = simd_prefix_inclusive_sum(added); + iadded = simd_prefix_inclusive_sum(iadded); + multiplied = simd_prefix_inclusive_product(multiplied); + imultiplied = simd_prefix_inclusive_product(imultiplied); + added = simd_prefix_exclusive_sum(multiplied); + multiplied = simd_prefix_exclusive_product(multiplied); + iadded = simd_prefix_exclusive_sum(imultiplied); + imultiplied = simd_prefix_exclusive_product(imultiplied); + added = spvClustered_sum<1>(added, gl_SubgroupInvocationID); + multiplied = spvClustered_product<1>(multiplied, gl_SubgroupInvocationID); + iadded = spvClustered_sum<1>(iadded, gl_SubgroupInvocationID); + imultiplied = spvClustered_product<1>(imultiplied, gl_SubgroupInvocationID); + lo = spvClustered_min<1>(lo, gl_SubgroupInvocationID); + hi = spvClustered_max<1>(hi, gl_SubgroupInvocationID); + ulo = spvClustered_min<1>(ulo, gl_SubgroupInvocationID); + uhi = spvClustered_max<1>(uhi, gl_SubgroupInvocationID); + slo = spvClustered_min<1>(slo, gl_SubgroupInvocationID); + shi = spvClustered_max<1>(shi, gl_SubgroupInvocationID); + anded = spvClustered_and<1>(anded, gl_SubgroupInvocationID); + ored = spvClustered_or<1>(ored, gl_SubgroupInvocationID); + xored = spvClustered_xor<1>(xored, gl_SubgroupInvocationID); + anded_b = bool4(spvClustered_and<1>(ushort4(anded == uint4(2u)), gl_SubgroupInvocationID)); + ored_b = bool4(spvClustered_or<1>(ushort4(ored == uint4(3u)), gl_SubgroupInvocationID)); + xored_b = bool4(spvClustered_xor<1>(ushort4(xored == uint4(4u)), gl_SubgroupInvocationID)); + added = spvClustered_sum<2>(added, gl_SubgroupInvocationID); + multiplied = spvClustered_product<2>(multiplied, gl_SubgroupInvocationID); + iadded = spvClustered_sum<2>(iadded, gl_SubgroupInvocationID); + imultiplied = spvClustered_product<2>(imultiplied, gl_SubgroupInvocationID); + lo = spvClustered_min<2>(lo, gl_SubgroupInvocationID); + hi = spvClustered_max<2>(hi, gl_SubgroupInvocationID); + ulo = spvClustered_min<2>(ulo, gl_SubgroupInvocationID); + uhi = spvClustered_max<2>(uhi, gl_SubgroupInvocationID); + slo = spvClustered_min<2>(slo, gl_SubgroupInvocationID); + shi = spvClustered_max<2>(shi, gl_SubgroupInvocationID); + anded = spvClustered_and<2>(anded, gl_SubgroupInvocationID); + ored = spvClustered_or<2>(ored, gl_SubgroupInvocationID); + xored = spvClustered_xor<2>(xored, gl_SubgroupInvocationID); + anded_b = bool4(spvClustered_and<2>(ushort4(anded == uint4(2u)), gl_SubgroupInvocationID)); + ored_b = bool4(spvClustered_or<2>(ushort4(ored == uint4(3u)), gl_SubgroupInvocationID)); + xored_b = bool4(spvClustered_xor<2>(ushort4(xored == uint4(4u)), gl_SubgroupInvocationID)); + added = spvClustered_sum<4>(added, gl_SubgroupInvocationID); + multiplied = spvClustered_product<4>(multiplied, gl_SubgroupInvocationID); + iadded = spvClustered_sum<4>(iadded, gl_SubgroupInvocationID); + imultiplied = spvClustered_product<4>(imultiplied, gl_SubgroupInvocationID); + lo = spvClustered_min<4>(lo, gl_SubgroupInvocationID); + hi = spvClustered_max<4>(hi, gl_SubgroupInvocationID); + ulo = spvClustered_min<4>(ulo, gl_SubgroupInvocationID); + uhi = spvClustered_max<4>(uhi, gl_SubgroupInvocationID); + slo = spvClustered_min<4>(slo, gl_SubgroupInvocationID); + shi = spvClustered_max<4>(shi, gl_SubgroupInvocationID); + anded = spvClustered_and<4>(anded, gl_SubgroupInvocationID); + ored = spvClustered_or<4>(ored, gl_SubgroupInvocationID); + xored = spvClustered_xor<4>(xored, gl_SubgroupInvocationID); + anded_b = bool4(spvClustered_and<4>(ushort4(anded == uint4(2u)), gl_SubgroupInvocationID)); + ored_b = bool4(spvClustered_or<4>(ushort4(ored == uint4(3u)), gl_SubgroupInvocationID)); + xored_b = bool4(spvClustered_xor<4>(ushort4(xored == uint4(4u)), gl_SubgroupInvocationID)); + added = spvClustered_sum<16>(added, gl_SubgroupInvocationID); + multiplied = spvClustered_product<16>(multiplied, gl_SubgroupInvocationID); + iadded = spvClustered_sum<16>(iadded, gl_SubgroupInvocationID); + imultiplied = spvClustered_product<16>(imultiplied, gl_SubgroupInvocationID); + lo = spvClustered_min<16>(lo, gl_SubgroupInvocationID); + hi = spvClustered_max<16>(hi, gl_SubgroupInvocationID); + ulo = spvClustered_min<16>(ulo, gl_SubgroupInvocationID); + uhi = spvClustered_max<16>(uhi, gl_SubgroupInvocationID); + slo = spvClustered_min<16>(slo, gl_SubgroupInvocationID); + shi = spvClustered_max<16>(shi, gl_SubgroupInvocationID); + anded = spvClustered_and<16>(anded, gl_SubgroupInvocationID); + ored = spvClustered_or<16>(ored, gl_SubgroupInvocationID); + xored = spvClustered_xor<16>(xored, gl_SubgroupInvocationID); + anded_b = bool4(spvClustered_and<16>(ushort4(anded == uint4(2u)), gl_SubgroupInvocationID)); + ored_b = bool4(spvClustered_or<16>(ushort4(ored == uint4(3u)), gl_SubgroupInvocationID)); + xored_b = bool4(spvClustered_xor<16>(ushort4(xored == uint4(4u)), gl_SubgroupInvocationID)); + float4 swap_horiz = spvQuadSwap(float4(20.0), 0u); + bool4 swap_horiz_bool = spvQuadSwap(bool4(true), 0u); + float4 swap_vertical = spvQuadSwap(float4(20.0), 1u); + bool4 swap_vertical_bool = spvQuadSwap(bool4(true), 1u); + float4 swap_diagonal = spvQuadSwap(float4(20.0), 2u); + bool4 swap_diagonal_bool = spvQuadSwap(bool4(true), 2u); + float4 quad_broadcast0 = spvQuadBroadcast(float4(20.0), 3u); + bool4 quad_broadcast_bool = spvQuadBroadcast(bool4(true), 3u); +} + diff --git a/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp b/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp new file mode 100644 index 000000000..c8172fd95 --- /dev/null +++ b/shaders-msl-no-opt/comp/subgroups.nocompat.vk.msl32.comp @@ -0,0 +1,211 @@ +#version 450 +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_clustered : require +#extension GL_KHR_shader_subgroup_quad : require +#extension GL_KHR_shader_subgroup_rotate : require +layout(local_size_x = 1) in; + +layout(std430, binding = 0) buffer SSBO +{ + float FragColor; +}; + +void doClusteredRotate() +{ + uint rotated_clustered = subgroupClusteredRotate(20u, 4u, 8u); + bool rotated_clustered_bool = subgroupClusteredRotate(false, 4u, 8u); +} + +void main() +{ + // basic + FragColor = float(gl_NumSubgroups); + FragColor = float(gl_SubgroupID); + FragColor = float(gl_SubgroupSize); + FragColor = float(gl_SubgroupInvocationID); + subgroupBarrier(); + subgroupMemoryBarrier(); + subgroupMemoryBarrierBuffer(); + subgroupMemoryBarrierShared(); + subgroupMemoryBarrierImage(); + bool elected = subgroupElect(); + + // ballot + FragColor = float(gl_SubgroupEqMask); + FragColor = float(gl_SubgroupGeMask); + FragColor = float(gl_SubgroupGtMask); + FragColor = float(gl_SubgroupLeMask); + FragColor = float(gl_SubgroupLtMask); + vec4 broadcasted = subgroupBroadcast(vec4(10.0), 8u); + bvec2 broadcasted_bool = subgroupBroadcast(bvec2(true), 8u); + vec3 first = subgroupBroadcastFirst(vec3(20.0)); + bvec4 first_bool = subgroupBroadcastFirst(bvec4(false)); + uvec4 ballot_value = subgroupBallot(true); + bool inverse_ballot_value = subgroupInverseBallot(ballot_value); + bool bit_extracted = subgroupBallotBitExtract(uvec4(10u), 8u); + uint bit_count = subgroupBallotBitCount(ballot_value); + uint inclusive_bit_count = subgroupBallotInclusiveBitCount(ballot_value); + uint exclusive_bit_count = subgroupBallotExclusiveBitCount(ballot_value); + uint lsb = subgroupBallotFindLSB(ballot_value); + uint msb = subgroupBallotFindMSB(ballot_value); + + // shuffle + uint shuffled = subgroupShuffle(10u, 8u); + bool shuffled_bool = subgroupShuffle(true, 9u); + uint shuffled_xor = subgroupShuffleXor(30u, 8u); + bool shuffled_xor_bool = subgroupShuffleXor(false, 9u); + + // shuffle relative + uint shuffled_up = subgroupShuffleUp(20u, 4u); + bool shuffled_up_bool = subgroupShuffleUp(true, 4u); + uint shuffled_down = subgroupShuffleDown(20u, 4u); + bool shuffled_down_bool = subgroupShuffleDown(false, 4u); + + // rotate + uint rotated = subgroupRotate(20u, 4u); + bool rotated_bool = subgroupRotate(false, 4u); + doClusteredRotate(); + + // vote + bool has_all = subgroupAll(true); + bool has_any = subgroupAny(true); + bool has_equal = subgroupAllEqual(0); + has_equal = subgroupAllEqual(true); + has_equal = subgroupAllEqual(vec3(0.0, 1.0, 2.0)); + has_equal = subgroupAllEqual(bvec4(true, true, false, true)); + + // arithmetic + vec4 added = subgroupAdd(vec4(20.0)); + ivec4 iadded = subgroupAdd(ivec4(20)); + vec4 multiplied = subgroupMul(vec4(20.0)); + ivec4 imultiplied = subgroupMul(ivec4(20)); + vec4 lo = subgroupMin(vec4(20.0)); + vec4 hi = subgroupMax(vec4(20.0)); + ivec4 slo = subgroupMin(ivec4(20)); + ivec4 shi = subgroupMax(ivec4(20)); + uvec4 ulo = subgroupMin(uvec4(20)); + uvec4 uhi = subgroupMax(uvec4(20)); + uvec4 anded = subgroupAnd(ballot_value); + uvec4 ored = subgroupOr(ballot_value); + uvec4 xored = subgroupXor(ballot_value); + bvec4 anded_b = subgroupAnd(equal(ballot_value, uvec4(42))); + bvec4 ored_b = subgroupOr(equal(ballot_value, uvec4(42))); + bvec4 xored_b = subgroupXor(equal(ballot_value, uvec4(42))); + + added = subgroupInclusiveAdd(added); + iadded = subgroupInclusiveAdd(iadded); + multiplied = subgroupInclusiveMul(multiplied); + imultiplied = subgroupInclusiveMul(imultiplied); + //lo = subgroupInclusiveMin(lo); // FIXME: Unsupported by Metal + //hi = subgroupInclusiveMax(hi); + //slo = subgroupInclusiveMin(slo); + //shi = subgroupInclusiveMax(shi); + //ulo = subgroupInclusiveMin(ulo); + //uhi = subgroupInclusiveMax(uhi); + //anded = subgroupInclusiveAnd(anded); + //ored = subgroupInclusiveOr(ored); + //xored = subgroupInclusiveXor(ored); + //added = subgroupExclusiveAdd(lo); + + added = subgroupExclusiveAdd(multiplied); + multiplied = subgroupExclusiveMul(multiplied); + iadded = subgroupExclusiveAdd(imultiplied); + imultiplied = subgroupExclusiveMul(imultiplied); + //lo = subgroupExclusiveMin(lo); // FIXME: Unsupported by Metal + //hi = subgroupExclusiveMax(hi); + //ulo = subgroupExclusiveMin(ulo); + //uhi = subgroupExclusiveMax(uhi); + //slo = subgroupExclusiveMin(slo); + //shi = subgroupExclusiveMax(shi); + //anded = subgroupExclusiveAnd(anded); + //ored = subgroupExclusiveOr(ored); + //xored = subgroupExclusiveXor(ored); + + // clustered + added = subgroupClusteredAdd(added, 1u); + multiplied = subgroupClusteredMul(multiplied, 1u); + iadded = subgroupClusteredAdd(iadded, 1u); + imultiplied = subgroupClusteredMul(imultiplied, 1u); + lo = subgroupClusteredMin(lo, 1u); + hi = subgroupClusteredMax(hi, 1u); + ulo = subgroupClusteredMin(ulo, 1u); + uhi = subgroupClusteredMax(uhi, 1u); + slo = subgroupClusteredMin(slo, 1u); + shi = subgroupClusteredMax(shi, 1u); + anded = subgroupClusteredAnd(anded, 1u); + ored = subgroupClusteredOr(ored, 1u); + xored = subgroupClusteredXor(xored, 1u); + + anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 1u); + ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 1u); + xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 1u); + + added = subgroupClusteredAdd(added, 2u); + multiplied = subgroupClusteredMul(multiplied, 2u); + iadded = subgroupClusteredAdd(iadded, 2u); + imultiplied = subgroupClusteredMul(imultiplied, 2u); + lo = subgroupClusteredMin(lo, 2u); + hi = subgroupClusteredMax(hi, 2u); + ulo = subgroupClusteredMin(ulo, 2u); + uhi = subgroupClusteredMax(uhi, 2u); + slo = subgroupClusteredMin(slo, 2u); + shi = subgroupClusteredMax(shi, 2u); + anded = subgroupClusteredAnd(anded, 2u); + ored = subgroupClusteredOr(ored, 2u); + xored = subgroupClusteredXor(xored, 2u); + + anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 2u); + ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 2u); + xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 2u); + + added = subgroupClusteredAdd(added, 4u); + multiplied = subgroupClusteredMul(multiplied, 4u); + iadded = subgroupClusteredAdd(iadded, 4u); + imultiplied = subgroupClusteredMul(imultiplied, 4u); + lo = subgroupClusteredMin(lo, 4u); + hi = subgroupClusteredMax(hi, 4u); + ulo = subgroupClusteredMin(ulo, 4u); + uhi = subgroupClusteredMax(uhi, 4u); + slo = subgroupClusteredMin(slo, 4u); + shi = subgroupClusteredMax(shi, 4u); + anded = subgroupClusteredAnd(anded, 4u); + ored = subgroupClusteredOr(ored, 4u); + xored = subgroupClusteredXor(xored, 4u); + + anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 4u); + ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 4u); + xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 4u); + + added = subgroupClusteredAdd(added, 16u); + multiplied = subgroupClusteredMul(multiplied, 16u); + iadded = subgroupClusteredAdd(iadded, 16u); + imultiplied = subgroupClusteredMul(imultiplied, 16u); + lo = subgroupClusteredMin(lo, 16u); + hi = subgroupClusteredMax(hi, 16u); + ulo = subgroupClusteredMin(ulo, 16u); + uhi = subgroupClusteredMax(uhi, 16u); + slo = subgroupClusteredMin(slo, 16u); + shi = subgroupClusteredMax(shi, 16u); + anded = subgroupClusteredAnd(anded, 16u); + ored = subgroupClusteredOr(ored, 16u); + xored = subgroupClusteredXor(xored, 16u); + + anded_b = subgroupClusteredAnd(equal(anded, uvec4(2u)), 16u); + ored_b = subgroupClusteredOr(equal(ored, uvec4(3u)), 16u); + xored_b = subgroupClusteredXor(equal(xored, uvec4(4u)), 16u); + + // quad + vec4 swap_horiz = subgroupQuadSwapHorizontal(vec4(20.0)); + bvec4 swap_horiz_bool = subgroupQuadSwapHorizontal(bvec4(true)); + vec4 swap_vertical = subgroupQuadSwapVertical(vec4(20.0)); + bvec4 swap_vertical_bool = subgroupQuadSwapVertical(bvec4(true)); + vec4 swap_diagonal = subgroupQuadSwapDiagonal(vec4(20.0)); + bvec4 swap_diagonal_bool = subgroupQuadSwapDiagonal(bvec4(true)); + vec4 quad_broadcast = subgroupQuadBroadcast(vec4(20.0), 3u); + bvec4 quad_broadcast_bool = subgroupQuadBroadcast(bvec4(true), 3u); +} diff --git a/spirv_msl.cpp b/spirv_msl.cpp index 6cf6b11b9..e6136991d 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -11103,7 +11103,7 @@ void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uin break; case ScopeSubgroup: - bar_stmt += ", thread_scope_subgroup"; + bar_stmt += ", thread_scope_simdgroup"; break; case ScopeInvocation: