Skip to content

Legalization of Workgroup function parameter #2430

@ehsannas

Description

@ehsannas

We are wondering whether spirv-opt can legalize/optimize the following scenario, in which a groupshared (Workgroup storage class) struct is passed as an argument to a function.

We have this sample HLSL shader:

struct testStruct {
	float Data[10][10];
};

float testFunc(testStruct t, uint topLeftX, uint topLeftY) {
  t.Data[0][0] = 2.0f * t.Data[0][0];
  return t.Data[0][0];
}
  
groupshared testStruct sharedStruct;

RWStructuredBuffer<float> testOut;

[numthreads(8, 8, 1)]
void testMain( uint3 DispatchThreadId	: SV_DispatchThreadID, uint3 GroupThreadID : SV_GroupThreadID, uint3 GroupID : SV_GroupID ) {
    float result = testFunc(sharedStruct, GroupThreadID.x, GroupThreadID.y);
    testOut[GroupThreadID.x] = result;
}

DXC compiles this shader to the following SPIR-V (without legalization or optimization):

; SPIR-V
; Version: 1.0
; Generator: Google spiregg; 0
; Bound: 71
; Schema: 0
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %testMain "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
               OpExecutionMode %testMain LocalSize 8 8 1
               OpSource HLSL 600
               OpName %testStruct "testStruct"
               OpMemberName %testStruct 0 "Data"
               OpName %sharedStruct "sharedStruct"
               OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
               OpName %testOut "testOut"
               OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
               OpMemberName %type_ACSBuffer_counter 0 "counter"
               OpName %counter_var_testOut "counter.var.testOut"
               OpName %testMain "testMain"
               OpName %param_var_DispatchThreadId "param.var.DispatchThreadId"
               OpName %param_var_GroupThreadID "param.var.GroupThreadID"
               OpName %param_var_GroupID "param.var.GroupID"
               OpName %src_testMain "src.testMain"
               OpName %DispatchThreadId "DispatchThreadId"
               OpName %GroupThreadID "GroupThreadID"
               OpName %GroupID "GroupID"
               OpName %bb_entry "bb.entry"
               OpName %result "result"
               OpName %param_var_topLeftX "param.var.topLeftX"
               OpName %param_var_topLeftY "param.var.topLeftY"
               OpName %testFunc "testFunc"
               OpName %t "t"
               OpName %topLeftX "topLeftX"
               OpName %topLeftY "topLeftY"
               OpName %bb_entry_0 "bb.entry"
               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
               OpDecorate %testOut DescriptorSet 0
               OpDecorate %testOut Binding 0
               OpDecorate %counter_var_testOut DescriptorSet 0
               OpDecorate %counter_var_testOut Binding 1
               OpDecorate %_runtimearr_float ArrayStride 4
               OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
               OpDecorate %type_RWStructuredBuffer_float BufferBlock
               OpMemberDecorate %type_ACSBuffer_counter 0 Offset 0
               OpDecorate %type_ACSBuffer_counter BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
      %int_1 = OpConstant %int 1
      %float = OpTypeFloat 32
    %float_2 = OpConstant %float 2
       %uint = OpTypeInt 32 0
    %uint_10 = OpConstant %uint 10
%_arr_float_uint_10 = OpTypeArray %float %uint_10
%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
 %testStruct = OpTypeStruct %_arr__arr_float_uint_10_uint_10
%_ptr_Workgroup_testStruct = OpTypePointer Workgroup %testStruct
%_runtimearr_float = OpTypeRuntimeArray %float
%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
%type_ACSBuffer_counter = OpTypeStruct %int
%_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter
     %v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
       %void = OpTypeVoid
         %27 = OpTypeFunction %void
%_ptr_Function_v3uint = OpTypePointer Function %v3uint
         %38 = OpTypeFunction %void %_ptr_Function_v3uint %_ptr_Function_v3uint %_ptr_Function_v3uint
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Uniform_float = OpTypePointer Uniform %float
%_ptr_Function_testStruct = OpTypePointer Function %testStruct
         %59 = OpTypeFunction %float %_ptr_Function_testStruct %_ptr_Function_uint %_ptr_Function_uint
%sharedStruct = OpVariable %_ptr_Workgroup_testStruct Workgroup
    %testOut = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
%counter_var_testOut = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
   %testMain = OpFunction %void None %27
         %28 = OpLabel
%param_var_DispatchThreadId = OpVariable %_ptr_Function_v3uint Function
%param_var_GroupThreadID = OpVariable %_ptr_Function_v3uint Function
%param_var_GroupID = OpVariable %_ptr_Function_v3uint Function
         %33 = OpLoad %v3uint %gl_GlobalInvocationID
         %34 = OpLoad %v3uint %gl_LocalInvocationID
               OpStore %param_var_GroupThreadID %34
         %35 = OpLoad %v3uint %gl_WorkGroupID
         %36 = OpFunctionCall %void %src_testMain %param_var_DispatchThreadId %param_var_GroupThreadID %param_var_GroupID
               OpReturn
               OpFunctionEnd
%src_testMain = OpFunction %void None %38
%DispatchThreadId = OpFunctionParameter %_ptr_Function_v3uint
%GroupThreadID = OpFunctionParameter %_ptr_Function_v3uint
    %GroupID = OpFunctionParameter %_ptr_Function_v3uint
   %bb_entry = OpLabel
     %result = OpVariable %_ptr_Function_float Function
%param_var_topLeftX = OpVariable %_ptr_Function_uint Function
%param_var_topLeftY = OpVariable %_ptr_Function_uint Function
         %48 = OpAccessChain %_ptr_Function_uint %GroupThreadID %int_0
         %49 = OpLoad %uint %48
               OpStore %param_var_topLeftX %49
         %50 = OpAccessChain %_ptr_Function_uint %GroupThreadID %int_1
         %51 = OpLoad %uint %50
               OpStore %param_var_topLeftY %51
         %52 = OpFunctionCall %float %testFunc %sharedStruct %param_var_topLeftX %param_var_topLeftY
               OpStore %result %52
         %54 = OpLoad %float %result
         %55 = OpAccessChain %_ptr_Function_uint %GroupThreadID %int_0
         %56 = OpLoad %uint %55
         %58 = OpAccessChain %_ptr_Uniform_float %testOut %int_0 %56
               OpStore %58 %54
               OpReturn
               OpFunctionEnd
   %testFunc = OpFunction %float None %59
          %t = OpFunctionParameter %_ptr_Function_testStruct
   %topLeftX = OpFunctionParameter %_ptr_Function_uint
   %topLeftY = OpFunctionParameter %_ptr_Function_uint
 %bb_entry_0 = OpLabel
         %65 = OpAccessChain %_ptr_Function_float %t %int_0 %int_0 %int_0
         %66 = OpLoad %float %65
         %67 = OpFMul %float %float_2 %66
         %68 = OpAccessChain %_ptr_Function_float %t %int_0 %int_0 %int_0
               OpStore %68 %67
         %69 = OpAccessChain %_ptr_Function_float %t %int_0 %int_0 %int_0
         %70 = OpLoad %float %69
               OpReturnValue %70
               OpFunctionEnd

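This is illegal, as the validator points out: %sharedStruct is a pointer in the Workgroup storage class, but the first parameter of %testFunc is a pointer in the Function storage class.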

OpFunctionCall Argument <id> '18[%sharedStruct]'s type does not match Function <id> '60[%_ptr_Function_testStruct]'s parameter type.
  %52 = OpFunctionCall %float %testFunc %sharedStruct %param_var_topLeftX %param_var_topLeftY

After running legalization and optimization, we get:

; SPIR-V
; Version: 1.0
; Generator: Google spiregg; 0
; Bound: 34
; Schema: 0
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %testMain "testMain" %gl_GlobalInvocationID %gl_LocalInvocationID %gl_WorkGroupID
               OpExecutionMode %testMain LocalSize 8 8 1
               OpSource HLSL 600
               OpName %testStruct "testStruct"
               OpMemberName %testStruct 0 "Data"
               OpName %sharedStruct "sharedStruct"
               OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
               OpName %testOut "testOut"
               OpName %testMain "testMain"
               OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId
               OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
               OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
               OpDecorate %testOut DescriptorSet 0
               OpDecorate %testOut Binding 0
               OpDecorate %_runtimearr_float ArrayStride 4
               OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
               OpDecorate %type_RWStructuredBuffer_float BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
      %float = OpTypeFloat 32
    %float_2 = OpConstant %float 2
       %uint = OpTypeInt 32 0
    %uint_10 = OpConstant %uint 10
%_arr_float_uint_10 = OpTypeArray %float %uint_10
%_arr__arr_float_uint_10_uint_10 = OpTypeArray %_arr_float_uint_10 %uint_10
 %testStruct = OpTypeStruct %_arr__arr_float_uint_10_uint_10
%_ptr_Workgroup_testStruct = OpTypePointer Workgroup %testStruct
%_runtimearr_float = OpTypeRuntimeArray %float
%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
     %v3uint = OpTypeVector %uint 3
%_ptr_Input_v3uint = OpTypePointer Input %v3uint
       %void = OpTypeVoid
         %23 = OpTypeFunction %void
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Uniform_float = OpTypePointer Uniform %float
%sharedStruct = OpVariable %_ptr_Workgroup_testStruct Workgroup
    %testOut = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
%gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
%gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
   %testMain = OpFunction %void None %23
         %26 = OpLabel
         %27 = OpLoad %v3uint %gl_LocalInvocationID
         %28 = OpAccessChain %_ptr_Function_float %sharedStruct %int_0 %int_0 %int_0
         %29 = OpLoad %float %28
         %30 = OpFMul %float %float_2 %29
               OpStore %28 %30
         %31 = OpLoad %float %28
         %32 = OpCompositeExtract %uint %27 0
         %33 = OpAccessChain %_ptr_Uniform_float %testOut %int_0 %32
               OpStore %33 %31
               OpReturn
               OpFunctionEnd

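This output is still illegal: after inlining, the OpAccessChain keeps a Function storage class result type even though its base is now the Workgroup variable %sharedStruct.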

The result pointer storage class and base pointer storage class in OpAccessChain do not match.
  %28 = OpAccessChain %_ptr_Function_float %sharedStruct %int_0 %int_0 %int_0

The HLSL->SPIRV compilation for this example was done using the code in this PR.

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions