diff --git a/spirv_hlsl.cpp b/spirv_hlsl.cpp index a18fa3c60..69b979578 100644 --- a/spirv_hlsl.cpp +++ b/spirv_hlsl.cpp @@ -769,6 +769,10 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() auto builtin = static_cast(i); switch (builtin) { + case BuiltInPosition: + type = "float4"; + semantic = legacy ? "POSITION" : "SV_Position"; + break; case BuiltInFragCoord: type = "float4"; semantic = legacy ? "VPOS" : "SV_Position"; @@ -783,8 +787,27 @@ void CompilerHLSL::emit_builtin_inputs_in_struct() break; case BuiltInPrimitiveId: - type = "uint"; - semantic = "SV_PrimitiveID"; + // For geometry shaders, PrimitiveId is a direct function parameter + // (SV_PrimitiveID), not part of the input struct. + if (get_entry_point().model != ExecutionModelGeometry) + { + type = "uint"; + semantic = "SV_PrimitiveID"; + } + break; + + case BuiltInInvocationId: + if (get_entry_point().model == ExecutionModelGeometry) + { + type = "uint"; + semantic = "SV_GSInstanceID"; + } + else if (get_entry_point().model != ExecutionModelTessellationControl) + { + // For tesc, InvocationId is a direct function parameter (SV_OutputControlPointID), + // not part of the input struct. + SPIRV_CROSS_THROW("InvocationId is only supported in geometry and tessellation control shaders."); + } break; case BuiltInInstanceId: @@ -1139,8 +1162,9 @@ void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unord (execution.model == ExecutionModelGeometry && var.storage == StorageClassInput) || has_decoration(var.self, DecorationPerVertexKHR)) { - decl_type.array.erase(decl_type.array.begin()); - decl_type.array_size_literal.erase(decl_type.array_size_literal.begin()); + // The per-vertex/per-CP dimension is the outermost (last element in array vector). + decl_type.array.pop_back(); + decl_type.array_size_literal.pop_back(); } statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(decl_type, name), " : ", semantic, ";"); @@ -1164,6 +1188,9 @@ std::string CompilerHLSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage) { switch (builtin) { + case BuiltInPosition: + // We want to avoid clash between input/output for geometry shader + return storage == StorageClass::StorageClassInput ? "gl_PositionIn" : "gl_Position"; case BuiltInVertexId: return "gl_VertexID"; case BuiltInInstanceId: @@ -1247,9 +1274,7 @@ void CompilerHLSL::emit_builtin_variables() // Emit global variables for the interface variables which are statically used by the shader. builtins.for_each_bit([&](uint32_t i) { - const char *type = nullptr; auto builtin = static_cast(i); - uint32_t array_size = 0; string init_expr; auto init_itr = builtin_to_initializer.find(builtin); @@ -1268,147 +1293,163 @@ void CompilerHLSL::emit_builtin_variables() } } - switch (builtin) + // If we need to emit 2 separate variables (for both input & output), we'll update this value + bool has_separate_input_output = false; + for (int variable_index = 0; variable_index < (has_separate_input_output ? 2 : 1); variable_index++) { - case BuiltInFragCoord: - case BuiltInPosition: - type = "float4"; - break; + uint32_t array_size = 0; + StorageClass storage = active_input_builtins.get(i) && variable_index == 0 + ? StorageClassInput + : StorageClassOutput; + const char *type = nullptr; + switch (builtin) + { + case BuiltInFragCoord: + type = "float4"; + break; - case BuiltInFragDepth: - type = "float"; - break; + case BuiltInPosition: + type = "float4"; + if (storage == StorageClass::StorageClassInput && + (get_execution_model() == ExecutionModelGeometry || + get_execution_model() == ExecutionModelTessellationControl)) + array_size = input_vertices_from_execution_mode(get_entry_point()); + break; - case BuiltInVertexId: - case BuiltInVertexIndex: - case BuiltInInstanceIndex: - type = "int"; - if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68) - base_vertex_info.used = true; - break; + case BuiltInFragDepth: + type = "float"; + break; - case BuiltInBaseVertex: - case BuiltInBaseInstance: - type = "int"; - base_vertex_info.used = true; - break; + case BuiltInVertexId: + case BuiltInVertexIndex: + case BuiltInInstanceIndex: + type = "int"; + if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68) + base_vertex_info.used = true; + break; - case BuiltInInstanceId: - case BuiltInSampleId: - type = "int"; - break; + case BuiltInBaseVertex: + case BuiltInBaseInstance: + type = "int"; + base_vertex_info.used = true; + break; - case BuiltInPointSize: - if (hlsl_options.point_size_compat || hlsl_options.shader_model <= 30) - { - // Just emit the global variable, it will be ignored. - type = "float"; + case BuiltInInstanceId: + case BuiltInSampleId: + type = "int"; break; - } - else - SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); - case BuiltInGlobalInvocationId: - case BuiltInLocalInvocationId: - case BuiltInWorkgroupId: - type = "uint3"; - break; + case BuiltInPointSize: + if (hlsl_options.point_size_compat || hlsl_options.shader_model <= 30) + { + // Just emit the global variable, it will be ignored. + type = "float"; + break; + } + else + SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); - case BuiltInLocalInvocationIndex: - type = "uint"; - break; + case BuiltInGlobalInvocationId: + case BuiltInLocalInvocationId: + case BuiltInWorkgroupId: + type = "uint3"; + break; - case BuiltInFrontFacing: - type = "bool"; - break; + case BuiltInLocalInvocationIndex: + type = "uint"; + break; - case BuiltInNumWorkgroups: - case BuiltInPointCoord: - // Handled specially. - break; + case BuiltInFrontFacing: + type = "bool"; + break; - case BuiltInSubgroupLocalInvocationId: - case BuiltInSubgroupSize: - if (hlsl_options.shader_model < 60) - SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); - break; + case BuiltInNumWorkgroups: + case BuiltInPointCoord: + // Handled specially. + break; - case BuiltInSubgroupEqMask: - case BuiltInSubgroupLtMask: - case BuiltInSubgroupLeMask: - case BuiltInSubgroupGtMask: - case BuiltInSubgroupGeMask: - if (hlsl_options.shader_model < 60) - SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); - type = "uint4"; - break; + case BuiltInSubgroupLocalInvocationId: + case BuiltInSubgroupSize: + if (hlsl_options.shader_model < 60) + SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); + break; - case BuiltInHelperInvocation: - if (hlsl_options.shader_model < 50) - SPIRV_CROSS_THROW("Need SM 5.0 for Helper Invocation."); - break; + case BuiltInSubgroupEqMask: + case BuiltInSubgroupLtMask: + case BuiltInSubgroupLeMask: + case BuiltInSubgroupGtMask: + case BuiltInSubgroupGeMask: + if (hlsl_options.shader_model < 60) + SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops."); + type = "uint4"; + break; - case BuiltInClipDistance: - array_size = clip_distance_count; - type = "float"; - break; + case BuiltInHelperInvocation: + if (hlsl_options.shader_model < 50) + SPIRV_CROSS_THROW("Need SM 5.0 for Helper Invocation."); + break; - case BuiltInCullDistance: - array_size = cull_distance_count; - type = "float"; - break; + case BuiltInClipDistance: + array_size = clip_distance_count; + type = "float"; + break; - case BuiltInSampleMask: - if (active_input_builtins.get(BuiltInSampleMask)) - type = sample_mask_in_basetype == SPIRType::UInt ? "uint" : "int"; - else - type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int"; - array_size = 1; - break; + case BuiltInCullDistance: + array_size = cull_distance_count; + type = "float"; + break; - case BuiltInPrimitiveId: - case BuiltInViewIndex: - case BuiltInLayer: - type = "uint"; - break; + case BuiltInSampleMask: + if (storage == StorageClass::StorageClassInput) + type = sample_mask_in_basetype == SPIRType::UInt ? "uint" : "int"; + else + type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int"; + array_size = 1; + break; - case BuiltInViewportIndex: - case BuiltInPrimitiveShadingRateKHR: - case BuiltInPrimitiveLineIndicesEXT: - case BuiltInCullPrimitiveEXT: - type = "uint"; - break; + case BuiltInPrimitiveId: + case BuiltInViewIndex: + case BuiltInLayer: + type = "uint"; + break; - case BuiltInBaryCoordKHR: - case BuiltInBaryCoordNoPerspKHR: - if (hlsl_options.shader_model < 61) - SPIRV_CROSS_THROW("Need SM 6.1 for barycentrics."); - type = "float3"; - break; + case BuiltInViewportIndex: + case BuiltInPrimitiveShadingRateKHR: + case BuiltInPrimitiveLineIndicesEXT: + case BuiltInCullPrimitiveEXT: + type = "uint"; + break; - default: - SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); - } + case BuiltInBaryCoordKHR: + case BuiltInBaryCoordNoPerspKHR: + if (hlsl_options.shader_model < 61) + SPIRV_CROSS_THROW("Need SM 6.1 for barycentrics."); + type = "float3"; + break; - StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput; + default: + SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin))); + } - if (type) - { - if (array_size) - statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";"); - else - statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";"); - } + if (type) + { + auto builtin_name = builtin_to_glsl(builtin, storage); + if (array_size) + statement("static ", type, " ", builtin_name, "[", array_size, "]", init_expr, ";"); + else + statement("static ", type, " ", builtin_name, init_expr, ";"); - // SampleMask can be both in and out with sample builtin, in this case we have already - // declared the input variable and we need to add the output one now. - if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i)) - { - type = sample_mask_out_basetype == SPIRType::UInt ? "uint" : "int"; - if (array_size) - statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), "[", array_size, "]", init_expr, ";"); - else - statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";"); + if (storage == StorageClassInput && this->active_output_builtins.get(i)) + { + auto out_builtin_name = builtin_to_glsl(builtin, StorageClassOutput); + if (out_builtin_name != builtin_name) + { + // If built-in name differs, we need to output it again + // (we reevaluate type and array size in case they are different) + has_separate_input_output = true; + } + } + } } }); @@ -3249,6 +3290,8 @@ void CompilerHLSL::emit_hlsl_entry_point() statement("[maxvertexcount(", execution.output_vertices, ")]"); arguments.push_back(join(prim, " SPIRV_Cross_Input stage_input[", input_vertices, "]")); + if (active_input_builtins.get(BuiltInPrimitiveId)) + arguments.push_back("uint gl_PrimitiveID : SV_PrimitiveID"); arguments.push_back(join("inout ", stream_type, " ", "geometry_stream")); break; } @@ -3351,6 +3394,17 @@ void CompilerHLSL::emit_hlsl_entry_point() auto builtin = builtin_to_glsl(static_cast(i), StorageClassInput); switch (static_cast(i)) { + case BuiltInPosition: + if (execution.model == ExecutionModelGeometry) + { + statement("for (int i = 0; i < ", input_vertices, "; i++)"); + begin_scope(); + statement(builtin, "[i] = stage_input[i].", builtin, ";"); + end_scope(); + } + else + statement(builtin, " = stage_input.", builtin, ";"); + break; case BuiltInFragCoord: // VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent. // TODO: Do we need an option here? Any reason why a D3D9 shader would be used @@ -3420,6 +3474,30 @@ void CompilerHLSL::emit_hlsl_entry_point() case BuiltInHelperInvocation: break; + case BuiltInPrimitiveId: + if (execution.model == ExecutionModelGeometry) + { + // PrimitiveId is a separate function parameter for GS. + // The global is named gl_PrimitiveIDIn (GLSL convention). + statement(builtin, " = gl_PrimitiveID;"); + } + else + statement(builtin, " = stage_input.", builtin, ";"); + break; + + case BuiltInInvocationId: + if (execution.model == ExecutionModelTessellationControl) + { + // Copy from function parameter to global. + statement(builtin, " = uCPID;"); + } + else + { + // For geometry shaders, copy from struct as usual. + statement(builtin, " = stage_input[0].", builtin, ";"); + } + break; + case BuiltInSubgroupEqMask: // Emulate these ... // No 64-bit in HLSL, so have to do it in 32-bit and unroll.