diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 13368de2bb..fb63dc7d78 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -978,8 +978,10 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() { const auto cast_name = namer_.MakeDerivedName(param_name + "_cast"); const auto cast_sym = builder_.Symbols().Register(cast_name); + // The parameter will have the WGSL type, but we need to add + // a bitcast to the variable store type. param_value = create( - source, forced_store_type->Build(builder_), param_value); + source, store_type->Build(builder_), param_value); } stmts.push_back(create( diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 62d92a16b7..01df33cd53 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1420,8 +1420,13 @@ bool ParserImpl::ConvertDecorationsForVariable( case SpvBuiltInSampleId: case SpvBuiltInVertexIndex: case SpvBuiltInInstanceIndex: - // The SPIR-V variable is likely to be signed (because GLSL - // requires signed), but WGSL requires unsigned. Handle specially + case SpvBuiltInLocalInvocationId: + case SpvBuiltInLocalInvocationIndex: + case SpvBuiltInGlobalInvocationId: + case SpvBuiltInWorkgroupId: + case SpvBuiltInNumWorkgroups: + // The SPIR-V variable may signed (because GLSL requires signed for + // some of these), but WGSL requires unsigned. Handle specially // so we always perform the conversion at load and store. if (auto* forced_type = UnsignedTypeFor(*type)) { // Requires conversion and special handling in code generation. diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index 984147cbda..f40b5f3f28 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -90,11 +90,16 @@ std::string CommonTypes() { %m3v2float = OpTypeMatrix %v2float 3 %arr2uint = OpTypeArray %uint %uint_2 - %strct = OpTypeStruct %uint %float %arr2uint )"; } -// Returns layout annotations for types in CommonTypes() +std::string StructTypes() { + return R"( + %strct = OpTypeStruct %uint %float %arr2uint +)"; +} + +// Returns layout annotations for types in StructTypes() std::string CommonLayout() { return R"( OpMemberDecorate %strct 0 Offset 0 @@ -1551,7 +1556,8 @@ TEST_F(SpvModuleScopeVarParserTest, ArrayUndefInitializer) { } TEST_F(SpvModuleScopeVarParserTest, StructInitializer) { - auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + R"( + auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + + StructTypes() + R"( %ptr = OpTypePointer Private %strct %two = OpConstant %uint 2 %arrconst = OpConstantComposite %arr2uint %uint_1 %two @@ -1583,7 +1589,8 @@ TEST_F(SpvModuleScopeVarParserTest, StructInitializer) { } TEST_F(SpvModuleScopeVarParserTest, StructNullInitializer) { - auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + R"( + auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + + StructTypes() + R"( %ptr = OpTypePointer Private %strct %const = OpConstantNull %strct %200 = OpVariable %ptr Private %const @@ -1613,7 +1620,8 @@ TEST_F(SpvModuleScopeVarParserTest, StructNullInitializer) { } TEST_F(SpvModuleScopeVarParserTest, StructUndefInitializer) { - auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + R"( + auto p = parser(test::Assemble(Preamble() + FragMain() + CommonTypes() + + StructTypes() + R"( %ptr = OpTypePointer Private %strct %const = OpUndef %strct %200 = OpVariable %ptr Private %const @@ -1704,7 +1712,8 @@ TEST_F(SpvModuleScopeVarParserTest, DescriptorGroupDecoration_Valid) { OpDecorate %1 DescriptorSet 3 OpDecorate %1 Binding 9 ; Required to pass WGSL validation OpDecorate %strct Block -)" + CommonTypes() + R"( +)" + CommonTypes() + StructTypes() + + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %1 = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody())); @@ -1730,7 +1739,8 @@ TEST_F(SpvModuleScopeVarParserTest, const auto assembly = Preamble() + FragMain() + CommonLayout() + R"( OpDecorate %1 DescriptorSet OpDecorate %strct Block -)" + CommonTypes() + R"( +)" + CommonTypes() + StructTypes() + + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %1 = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody(); @@ -1744,7 +1754,8 @@ TEST_F(SpvModuleScopeVarParserTest, OpName %myvar "myvar" OpDecorate %myvar DescriptorSet 3 4 OpDecorate %strct Block -)" + CommonTypes() + R"( +)" + CommonTypes() + StructTypes() + + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %myvar = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody(); @@ -1760,6 +1771,7 @@ TEST_F(SpvModuleScopeVarParserTest, BindingDecoration_Valid) { OpDecorate %1 Binding 3 OpDecorate %strct Block )" + CommonLayout() + CommonTypes() + + StructTypes() + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %1 = OpVariable %ptr_sb_strct StorageBuffer @@ -1787,7 +1799,8 @@ TEST_F(SpvModuleScopeVarParserTest, OpName %myvar "myvar" OpDecorate %myvar Binding OpDecorate %strct Block -)" + CommonTypes() + R"( +)" + CommonTypes() + StructTypes() + + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %myvar = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody(); @@ -1800,7 +1813,8 @@ TEST_F(SpvModuleScopeVarParserTest, BindingDecoration_TwoOperandsWontAssemble) { OpName %myvar "myvar" OpDecorate %myvar Binding 3 4 OpDecorate %strct Block -)" + CommonTypes() + R"( +)" + CommonTypes() + StructTypes() + + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %myvar = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody(); @@ -1818,7 +1832,7 @@ TEST_F(SpvModuleScopeVarParserTest, OpDecorate %strct Block OpMemberDecorate %strct 0 NonReadable )" + CommonLayout() + CommonTypes() + - R"( + StructTypes() + R"( %ptr_sb_strct = OpTypePointer StorageBuffer %strct %1 = OpVariable %ptr_sb_strct StorageBuffer )" + MainBody())); @@ -4317,10 +4331,153 @@ TEST_F(SpvModuleScopeVarParserTest, EntryPointWrapping_IOLocations) { EXPECT_THAT(got, HasSubstr(expected)) << got; } -// TODO(dneto): pipeline IO: convert signedness on builtin inputs in the wrapper -// body -// TODO(dneto): pipeline IO: convert signedness on builtin outputs in the -// wrapper body +TEST_F(SpvModuleScopeVarParserTest, + EntryPointWrapping_BuiltinVar_Input_SameSignedness) { + // local_invocation_index is u32 in WGSL. Use uint in SPIR-V. + // No bitcasts are used for parameter formation or return value. + const auto assembly = CommonCapabilities() + R"( + OpEntryPoint GLCompute %main "main" %1 + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %1 BuiltIn LocalInvocationIndex +)" + CommonTypes() + + R"( + %ptr_in_uint = OpTypePointer Input %uint + %1 = OpVariable %ptr_in_uint Input + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + + // TODO(crbug.com/tint/508): Remove this when everything is converted + // to HLSL style pipeline IO. + p->SetHLSLStylePipelineIO(); + + ASSERT_TRUE(p->Parse()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + const auto got = p->program().to_str(); + const std::string expected = R"(Module{ + Variable{ + x_1 + private + undefined + __u32 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __void + StageDecoration{compute} + ( + VariableConst{ + Decorations{ + BuiltinDecoration{local_invocation_index} + } + x_1_param + none + undefined + __u32 + } + ) + { + Assignment{ + Identifier[not set]{x_1} + Identifier[not set]{x_1_param} + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + } +} +)"; + EXPECT_EQ(got, expected) << got; +} + +TEST_F(SpvModuleScopeVarParserTest, + EntryPointWrapping_BuiltinVar_Input_OppositeSignedness) { + // local_invocation_index is u32 in WGSL. Use int in SPIR-V. + const auto assembly = CommonCapabilities() + R"( + OpEntryPoint GLCompute %main "main" %1 + OpExecutionMode %main LocalSize 1 1 1 + OpDecorate %1 BuiltIn LocalInvocationIndex +)" + CommonTypes() + + R"( + %ptr_in_int = OpTypePointer Input %int + %1 = OpVariable %ptr_in_int Input + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + + // TODO(crbug.com/tint/508): Remove this when everything is converted + // to HLSL style pipeline IO. + p->SetHLSLStylePipelineIO(); + + ASSERT_TRUE(p->Parse()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + const auto got = p->program().to_str(); + const std::string expected = R"(Module{ + Variable{ + x_1 + private + undefined + __i32 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __void + StageDecoration{compute} + ( + VariableConst{ + Decorations{ + BuiltinDecoration{local_invocation_index} + } + x_1_param + none + undefined + __u32 + } + ) + { + Assignment{ + Identifier[not set]{x_1} + Bitcast[not set]<__i32>{ + Identifier[not set]{x_1_param} + } + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + } +} +)"; + EXPECT_EQ(got, expected) << got; +} + +// SampleMask is an array in Vulkan SPIR-V, but a scalar in WGSL. +TEST_F(SpvModuleScopeVarParserTest, + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U) {} +TEST_F(SpvModuleScopeVarParserTest, + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U_Initializer) {} +TEST_F(SpvModuleScopeVarParserTest, + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S) {} +TEST_F(SpvModuleScopeVarParserTest, + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S_Initializer) {} + // TODO(dneto): pipeline IO: flatten structures, and distribute locations // TODO(dneto): Test passing pointer to SampleMask as function parameter,