diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index fb63dc7d78..5022c6f354 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -685,6 +685,15 @@ struct LoopStatementBuilder ast::BlockStatement* continuing = nullptr; }; +/// @param decos a list of parsed decorations +/// @returns true if the decorations include a SampleMask builtin +bool HasBuiltinSampleMask(const ast::DecorationList& decos) { + if (auto* builtin = ast::GetDecoration(decos)) { + return builtin->value() == ast::Builtin::kSampleMask; + } + return false; +} + } // namespace BlockInfo::BlockInfo(const spvtools::opt::BasicBlock& bb) @@ -973,11 +982,17 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() { // variable. ast::Expression* param_value = create(source, param_sym); - if (forced_store_type != store_type) { - // Insert a bitcast if needed. - const auto cast_name = namer_.MakeDerivedName(param_name + "_cast"); - const auto cast_sym = builder_.Symbols().Register(cast_name); - + if (HasBuiltinSampleMask(param_decos)) { + // In Vulkan SPIR-V, the sample mask is an array. In WGSL it's a scalar. + // Use the first element only. + param_value = create( + source, param_value, parser_impl_.MakeNullValue(ty_.I32())); + if (store_type->As()->type->IsSignedScalarOrVector()) { + // sample_mask is unsigned in WGSL. Bitcast it. + param_value = create( + source, ty_.I32()->Build(builder_), param_value); + } + } else if (forced_store_type != store_type) { // The parameter will have the WGSL type, but we need to add // a bitcast to the variable store type. param_value = create( @@ -1046,9 +1061,15 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() { std::move(out_decos)); return_members.push_back(return_member); + ast::Expression* return_member_value = + create(source, var_sym); + if (forced_store_type != store_type) { + // We need to cast from the variable store type to the member type. + return_member_value = create( + source, forced_store_type->Build(builder_), return_member_value); + } // Save the expression. - return_exprs.push_back( - create(source, var_sym)); + return_exprs.push_back(return_member_value); } // Create and register the result type. diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index f40b5f3f28..a755f13d2c 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -4470,13 +4470,155 @@ TEST_F(SpvModuleScopeVarParserTest, // SampleMask is an array in Vulkan SPIR-V, but a scalar in WGSL. TEST_F(SpvModuleScopeVarParserTest, - DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U) {} + EntryPointWrapping_BuiltinVar_SampleMask_In_Unsigned) { + // SampleMask is u32 in WGSL. + // Use unsigned array element in Vulkan. + const auto assembly = CommonCapabilities() + R"( + OpEntryPoint Fragment %main "main" %1 + OpExecutionMode %main OriginUpperLeft + OpDecorate %1 BuiltIn SampleMask +)" + CommonTypes() + + R"( + %arr = OpTypeArray %uint %uint_1 + %ptr_ty = OpTypePointer Input %arr + %1 = OpVariable %ptr_ty Input + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + + // TODO(crbug.com/tint/508): Remove this when everything is converted + // to HLSL style pipeline IO. + p->SetHLSLStylePipelineIO(); + + ASSERT_TRUE(p->Parse()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + const auto got = p->program().to_str(); + const std::string expected = R"(Module{ + Variable{ + x_1 + private + undefined + __array__u32_1 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __void + StageDecoration{fragment} + ( + VariableConst{ + Decorations{ + BuiltinDecoration{sample_mask} + } + x_1_param + none + undefined + __u32 + } + ) + { + Assignment{ + Identifier[not set]{x_1} + ArrayAccessor[not set]{ + Identifier[not set]{x_1_param} + ScalarConstructor[not set]{0} + } + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + } +} +)"; + EXPECT_EQ(got, expected) << got; +} + TEST_F(SpvModuleScopeVarParserTest, - DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_U_Initializer) {} + EntryPointWrapping_BuiltinVar_SampleMask_In_Signed) { + // SampleMask is u32 in WGSL. + // Use signed array element in Vulkan. + const auto assembly = CommonCapabilities() + R"( + OpEntryPoint Fragment %main "main" %1 + OpExecutionMode %main OriginUpperLeft + OpDecorate %1 BuiltIn SampleMask +)" + CommonTypes() + + R"( + %arr = OpTypeArray %int %uint_1 + %ptr_ty = OpTypePointer Input %arr + %1 = OpVariable %ptr_ty Input + + %main = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + + // TODO(crbug.com/tint/508): Remove this when everything is converted + // to HLSL style pipeline IO. + p->SetHLSLStylePipelineIO(); + + ASSERT_TRUE(p->Parse()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + const auto got = p->program().to_str(); + const std::string expected = R"(Module{ + Variable{ + x_1 + private + undefined + __array__i32_1 + } + Function main_1 -> __void + () + { + Return{} + } + Function main -> __void + StageDecoration{fragment} + ( + VariableConst{ + Decorations{ + BuiltinDecoration{sample_mask} + } + x_1_param + none + undefined + __u32 + } + ) + { + Assignment{ + Identifier[not set]{x_1} + Bitcast[not set]<__i32>{ + ArrayAccessor[not set]{ + Identifier[not set]{x_1_param} + ScalarConstructor[not set]{0} + } + } + } + Call[not set]{ + Identifier[not set]{main_1} + ( + ) + } + } +} +)"; + EXPECT_EQ(got, expected) << got; +} + TEST_F(SpvModuleScopeVarParserTest, - DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S) {} + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_Out_U) {} TEST_F(SpvModuleScopeVarParserTest, - DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_S_Initializer) {} + DISABLED_EntryPointWrapping_BuiltinVar_SampleMask_Out_S) {} // TODO(dneto): pipeline IO: flatten structures, and distribute locations