diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 3dc70abad8..3c9b9350a4 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -1059,11 +1059,11 @@ TEST_F(InspectorGetEntryPointTest, MultipleEntryPointsInOutVariables) { TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) { auto* in_var0 = - Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kInstanceIndex)}); - auto* in_var1 = Param("in_var1", ty.u32(), {Location(0u)}); - Func("foo", {in_var0, in_var1}, ty.u32(), {Return("in_var1")}, + Param("in_var0", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)}); + auto* in_var1 = Param("in_var1", ty.f32(), {Location(0u)}); + Func("foo", {in_var0, in_var1}, ty.f32(), {Return("in_var1")}, {Stage(ast::PipelineStage::kFragment)}, - {Builtin(ast::Builtin::kSampleMask)}); + {Builtin(ast::Builtin::kFragDepth)}); Inspector& inspector = Build(); auto result = inspector.GetEntryPoints(); @@ -1075,7 +1075,7 @@ TEST_F(InspectorGetEntryPointTest, BuiltInsNotStageVariables) { EXPECT_EQ("in_var1", result[0].input_variables[0].name); EXPECT_TRUE(result[0].input_variables[0].has_location_decoration); EXPECT_EQ(0u, result[0].input_variables[0].location_decoration); - EXPECT_EQ(ComponentType::kUInt, result[0].input_variables[0].component_type); + EXPECT_EQ(ComponentType::kFloat, result[0].input_variables[0].component_type); ASSERT_EQ(0u, result[0].output_variables.size()); } diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc index 6c06ce021e..46a2b6637c 100644 --- a/src/resolver/builtins_validation_test.cc +++ b/src/resolver/builtins_validation_test.cc @@ -16,9 +16,209 @@ #include "src/resolver/resolver_test_helper.h" namespace tint { +namespace resolver { namespace { + +template +using DataType = builder::DataType; +template +using vec2 = builder::vec2; +template +using vec3 = builder::vec3; +template +using vec4 = builder::vec4; +template +using f32 = builder::f32; +using i32 = builder::i32; +using u32 = builder::u32; + class ResolverBuiltinsValidationTest : public resolver::TestHelper, public testing::Test {}; +namespace TypeTemp { +struct Params { + builder::ast_type_func_ptr type; + ast::Builtin builtin; + ast::PipelineStage stage; + bool is_valid; +}; + +template +constexpr Params ParamsFor(ast::Builtin builtin, + ast::PipelineStage stage, + bool is_valid) { + return Params{DataType::AST, builtin, stage, is_valid}; +} +static constexpr Params cases[] = { + ParamsFor(ast::Builtin::kVertexIndex, + ast::PipelineStage::kVertex, + true), + ParamsFor(ast::Builtin::kVertexIndex, + ast::PipelineStage::kFragment, + false), + ParamsFor(ast::Builtin::kVertexIndex, + ast::PipelineStage::kCompute, + false), + + ParamsFor(ast::Builtin::kInstanceIndex, + ast::PipelineStage::kVertex, + true), + ParamsFor(ast::Builtin::kInstanceIndex, + ast::PipelineStage::kFragment, + false), + ParamsFor(ast::Builtin::kInstanceIndex, + ast::PipelineStage::kCompute, + false), + + ParamsFor(ast::Builtin::kFrontFacing, + ast::PipelineStage::kVertex, + false), + ParamsFor(ast::Builtin::kFrontFacing, + ast::PipelineStage::kFragment, + true), + ParamsFor(ast::Builtin::kFrontFacing, + ast::PipelineStage::kCompute, + false), + + ParamsFor>(ast::Builtin::kLocalInvocationId, + ast::PipelineStage::kVertex, + false), + ParamsFor>(ast::Builtin::kLocalInvocationId, + ast::PipelineStage::kFragment, + false), + ParamsFor>(ast::Builtin::kLocalInvocationId, + ast::PipelineStage::kCompute, + true), + + ParamsFor(ast::Builtin::kLocalInvocationIndex, + ast::PipelineStage::kVertex, + false), + ParamsFor(ast::Builtin::kLocalInvocationIndex, + ast::PipelineStage::kFragment, + false), + ParamsFor(ast::Builtin::kLocalInvocationIndex, + ast::PipelineStage::kCompute, + true), + + ParamsFor>(ast::Builtin::kGlobalInvocationId, + ast::PipelineStage::kVertex, + false), + ParamsFor>(ast::Builtin::kGlobalInvocationId, + ast::PipelineStage::kFragment, + false), + ParamsFor>(ast::Builtin::kGlobalInvocationId, + ast::PipelineStage::kCompute, + true), + + ParamsFor>(ast::Builtin::kWorkgroupId, + ast::PipelineStage::kVertex, + false), + ParamsFor>(ast::Builtin::kWorkgroupId, + ast::PipelineStage::kFragment, + false), + ParamsFor>(ast::Builtin::kWorkgroupId, + ast::PipelineStage::kCompute, + true), + + ParamsFor(ast::Builtin::kSampleIndex, + ast::PipelineStage::kVertex, + false), + ParamsFor(ast::Builtin::kSampleIndex, + ast::PipelineStage::kFragment, + true), + ParamsFor(ast::Builtin::kSampleIndex, + ast::PipelineStage::kCompute, + false), + + ParamsFor(ast::Builtin::kSampleMask, + ast::PipelineStage::kVertex, + false), + ParamsFor(ast::Builtin::kSampleMask, + ast::PipelineStage::kFragment, + true), + ParamsFor(ast::Builtin::kSampleMask, + ast::PipelineStage::kCompute, + false), +}; + +using ResolverBuiltinsStageTest = ResolverTestWithParam; +TEST_P(ResolverBuiltinsStageTest, All_input) { + const Params& params = GetParam(); + + auto* p = Global("p", ty.vec4(), ast::StorageClass::kPrivate); + auto* input = + Param("input", params.type(*this), + ast::DecorationList{Builtin(Source{{12, 34}}, params.builtin)}); + switch (params.stage) { + case ast::PipelineStage::kVertex: + Func("main", {input}, ty.vec4(), {Return(p)}, + {Stage(ast::PipelineStage::kVertex)}, + {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)}); + break; + case ast::PipelineStage::kFragment: + Func("main", {input}, ty.void_(), {}, + {Stage(ast::PipelineStage::kFragment)}, {}); + break; + case ast::PipelineStage::kCompute: + Func("main", {input}, ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1)}); + break; + default: + break; + } + + if (params.is_valid) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + std::stringstream err; + err << "12:34 error: builtin(" << params.builtin << ")"; + err << " cannot be used in input of " << params.stage << " pipeline stage"; + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), err.str()); + } +} +INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest, + ResolverBuiltinsStageTest, + testing::ValuesIn(cases)); + +TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) { + // [[stage(fragment)]] + // fn fs_main( + // [[builtin(kFragDepth)]] fd: f32, + // ) -> [[location(0)]] f32 { return 1.0; } + auto* fd = Param( + "fd", ty.f32(), + ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)}); + Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, + {Location(0)}); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: builtin(frag_depth) cannot be used in input of " + "fragment pipeline stage"); +} + +TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) { + // Struct MyInputs { + // [[builtin(front_facing)]] ff: bool; + // }; + // [[stage(fragment)]] + // fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; } + + auto* s = Structure( + "MyInputs", {Member("frag_depth", ty.f32(), + ast::DecorationList{Builtin( + Source{{12, 34}}, ast::Builtin::kFragDepth)})}); + + Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)}, + {Stage(ast::PipelineStage::kFragment)}, {Location(0)}); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: builtin(frag_depth) cannot be used in input of fragment " + "pipeline stage\nnote: while analysing entry point fragShader"); +} +} // namespace TypeTemp TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) { // struct MyInputs { @@ -170,15 +370,12 @@ TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) { TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) { // [[stage(fragment)]] - // fn fs_main( - // [[builtin(kFragDepth)]] fd: f32, - // ) -> [[location(0)]] f32 { return 1.0; } - auto* fd = Param( - "fd", ty.i32(), + // fn fs_main() -> [[builtin(kFragDepth)]] f32 { var fd: i32; return fd; } + auto* fd = Var("fd", ty.i32()); + Func( + "fs_main", {}, ty.i32(), {Decl(fd), Return(fd)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)}); - Func("fs_main", ast::VariableList{fd}, ty.f32(), {Return(1.0f)}, - ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, - {Location(0)}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: store type of builtin(frag_depth) must be 'f32'"); @@ -227,44 +424,43 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltin_Pass) { // fn fs_main( // [[builtin(kPosition)]] p: vec4, // [[builtin(front_facing)]] ff: bool, - // [[builtin(frag_depth)]] fd: f32, // [[builtin(sample_index)]] si: u32, // [[builtin(sample_mask)]] sm : u32 - // ) -> [[location(0)]] f32 { return 1.0; } + // ) -> [[builtin(frag_depth)]] f32 { var fd: f32; return fd; } auto* p = Param("p", ty.vec4(), ast::DecorationList{Builtin(ast::Builtin::kPosition)}); auto* ff = Param("ff", ty.bool_(), ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}); - auto* fd = Param("fd", ty.f32(), - ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}); auto* si = Param("si", ty.u32(), ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}); auto* sm = Param("sm", ty.u32(), ast::DecorationList{Builtin(ast::Builtin::kSampleMask)}); - Func( - "fs_main", ast::VariableList{p, ff, fd, si, sm}, ty.f32(), {Return(1.0f)}, - ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, {Location(0)}); + auto* var_fd = Var("fd", ty.f32()); + Func("fs_main", ast::VariableList{p, ff, si, sm}, ty.f32(), + {Decl(var_fd), Return(var_fd)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, + ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}); EXPECT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) { // [[stage(vertex)]] // fn main( - // [[builtin(kVertexIndex)]] vi : u32, - // [[builtin(kInstanceIndex)]] ii : u32, - // [[builtin(kPosition)]] p :vec4 - // ) {} + // [[builtin(vertex_index)]] vi : u32, + // [[builtin(instance_index)]] ii : u32, + // ) -> [[builtin(position)]] vec4 { var p :vec4; return p; } auto* vi = Param("vi", ty.u32(), ast::DecorationList{ Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)}); - auto* p = Param("p", ty.vec4(), - ast::DecorationList{Builtin(ast::Builtin::kPosition)}); + auto* ii = Param("ii", ty.u32(), ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kInstanceIndex)}); - Func("main", ast::VariableList{vi, ii, p}, ty.vec4(), + auto* p = Var("p", ty.vec4()); + Func("main", ast::VariableList{vi, ii}, ty.vec4(), { - Return(Expr(p)), + Decl(p), + Return(p), }, ast::DecorationList{Stage(ast::PipelineStage::kVertex)}, ast::DecorationList{Builtin(ast::Builtin::kPosition)}); @@ -369,7 +565,6 @@ TEST_F(ResolverBuiltinsValidationTest, TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) { // Struct MyInputs { // [[builtin(kPosition)]] p: vec4; - // [[builtin(front_facing)]] ff: bool; // [[builtin(frag_depth)]] fd: f32; // [[builtin(sample_index)]] si: u32; // [[builtin(sample_mask)]] sm : u32;; @@ -383,8 +578,6 @@ TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) { ast::DecorationList{Builtin(ast::Builtin::kPosition)}), Member("front_facing", ty.bool_(), ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}), - Member("frag_depth", ty.f32(), - ast::DecorationList{Builtin(ast::Builtin::kFragDepth)}), Member("sample_index", ty.u32(), ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}), Member("sample_mask", ty.u32(), @@ -1006,4 +1199,5 @@ INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest, "pack2x16float")); } // namespace +} // namespace resolver } // namespace tint diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index ba261a8ee0..61f8d36834 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -145,7 +145,8 @@ TEST_P(FunctionParameterDecorationTest, IsValid) { } else { EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_EQ(r()->error(), - "error: decoration is not valid for function parameters"); + "error: decoration is not valid for non-entry point function " + "parameters"); } } INSTANTIATE_TEST_SUITE_P( @@ -244,7 +245,8 @@ TEST_P(FunctionReturnTypeDecorationTest, IsValid) { } else { EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_EQ(r()->error(), - "error: decoration is not valid for function return types"); + "error: decoration is not valid for non-entry point function " + "return types"); } } INSTANTIATE_TEST_SUITE_P( diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc index 6b63223455..bd0201f1a3 100644 --- a/src/resolver/entry_point_validation_test.cc +++ b/src/resolver/entry_point_validation_test.cc @@ -289,16 +289,6 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Location) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Builtin) { - // [[stage(fragment)]] - // fn main([[builtin(frag_depth)]] param : f32) {} - auto* param = Param("param", ty.f32(), {Builtin(ast::Builtin::kFragDepth)}); - Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, - {Stage(ast::PipelineStage::kFragment)}); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); -} - TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) { // [[stage(fragment)]] // fn main(param : f32) {} @@ -313,10 +303,10 @@ TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Missing) { TEST_F(ResolverEntryPointValidationTest, ParameterAttribute_Multiple) { // [[stage(fragment)]] - // fn main([[location(0)]] [[builtin(vertex_index)]] param : u32) {} + // fn main([[location(0)]] [[builtin(sample_index)]] param : u32) {} auto* param = Param("param", ty.u32(), {Location(Source{{13, 43}}, 0), - Builtin(Source{{14, 52}}, ast::Builtin::kVertexIndex)}); + Builtin(Source{{14, 52}}, ast::Builtin::kSampleIndex)}); Func(Source{{12, 34}}, "main", {param}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 7a4c7a7174..7002b5a52a 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -928,20 +928,15 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, for (auto* deco : info->declaration->decorations()) { if (!func->IsEntryPoint() && !deco->Is()) { - AddError("decoration is not valid for function parameters", - deco->source()); + AddError( + "decoration is not valid for non-entry point function parameters", + deco->source()); return false; - } - - if (auto* builtin = deco->As()) { - if (!ValidateBuiltinDecoration(builtin, info->type)) { - return false; - } } else if (auto* interpolate = deco->As()) { if (!ValidateInterpolateDecoration(interpolate, info->type)) { return false; } - } else if (!deco->IsAnyOfIsAnyOf() && (IsValidationEnabled( info->declaration->decorations(), @@ -989,10 +984,25 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, } bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, - const sem::Type* storage_type) { + const sem::Type* storage_type, + const bool is_input) { auto* type = storage_type->UnwrapRef(); + const auto stage = current_function_ + ? current_function_->declaration->pipeline_stage() + : ast::PipelineStage::kNone; + std::stringstream stage_name; + stage_name << stage; + bool is_stage_mismatch = false; switch (deco->value()) { case ast::Builtin::kPosition: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kFragment && is_input) && + !(stage == ast::PipelineStage::kVertex && !is_input)) { + AddError(deco_to_str(deco) + " cannot be used in " + + (is_input ? "input of " : "output of ") + + stage_name.str() + " pipeline stage", + deco->source()); + } if (!(type->is_float_vector() && type->As()->size() == 4)) { AddError("store type of " + deco_to_str(deco) + " must be 'vec4'", deco->source()); @@ -1002,6 +1012,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, case ast::Builtin::kGlobalInvocationId: case ast::Builtin::kLocalInvocationId: case ast::Builtin::kWorkgroupId: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kCompute && is_input)) { + is_stage_mismatch = true; + } if (!(type->is_unsigned_integer_vector() && type->As()->size() == 3)) { AddError("store type of " + deco_to_str(deco) + " must be 'vec3'", @@ -1010,6 +1024,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, } break; case ast::Builtin::kFragDepth: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kFragment && !is_input)) { + is_stage_mismatch = true; + } if (!type->Is()) { AddError("store type of " + deco_to_str(deco) + " must be 'f32'", deco->source()); @@ -1017,6 +1035,10 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, } break; case ast::Builtin::kFrontFacing: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kFragment && is_input)) { + is_stage_mismatch = true; + } if (!type->Is()) { AddError("store type of " + deco_to_str(deco) + " must be 'bool'", deco->source()); @@ -1024,10 +1046,44 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, } break; case ast::Builtin::kLocalInvocationIndex: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kCompute && is_input)) { + is_stage_mismatch = true; + } + if (!type->Is()) { + AddError("store type of " + deco_to_str(deco) + " must be 'u32'", + deco->source()); + return false; + } + break; case ast::Builtin::kVertexIndex: case ast::Builtin::kInstanceIndex: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kVertex && is_input)) { + is_stage_mismatch = true; + } + if (!type->Is()) { + AddError("store type of " + deco_to_str(deco) + " must be 'u32'", + deco->source()); + return false; + } + break; case ast::Builtin::kSampleMask: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kFragment)) { + is_stage_mismatch = true; + } + if (!type->Is()) { + AddError("store type of " + deco_to_str(deco) + " must be 'u32'", + deco->source()); + return false; + } + break; case ast::Builtin::kSampleIndex: + if (stage != ast::PipelineStage::kNone && + !(stage == ast::PipelineStage::kFragment && is_input)) { + is_stage_mismatch = true; + } if (!type->Is()) { AddError("store type of " + deco_to_str(deco) + " must be 'u32'", deco->source()); @@ -1037,6 +1093,15 @@ bool Resolver::ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, default: break; } + + if (is_stage_mismatch) { + AddError(deco_to_str(deco) + " cannot be used in " + + (is_input ? "input of " : "output of ") + stage_name.str() + + " pipeline stage", + deco->source()); + return false; + } + return true; } @@ -1070,12 +1135,9 @@ bool Resolver::ValidateFunction(const ast::Function* func, return false; } - auto stage_deco_count = 0; auto workgroup_deco_count = 0; for (auto* deco : func->decorations()) { - if (deco->Is()) { - stage_deco_count++; - } else if (deco->Is()) { + if (deco->Is()) { workgroup_deco_count++; if (func->pipeline_stage() != ast::PipelineStage::kCompute) { AddError( @@ -1083,7 +1145,8 @@ bool Resolver::ValidateFunction(const ast::Function* func, deco->source()); return false; } - } else if (!deco->Is()) { + } else if (!deco->IsAnyOf()) { AddError("decoration is not valid for functions", deco->source()); return false; } @@ -1119,20 +1182,24 @@ bool Resolver::ValidateFunction(const ast::Function* func, for (auto* deco : func->return_type_decorations()) { if (!func->IsEntryPoint()) { - AddError("decoration is not valid for function return types", - deco->source()); + AddError( + "decoration is not valid for non-entry point function return types", + deco->source()); return false; } - if (auto* builtin = deco->As()) { - if (!ValidateBuiltinDecoration(builtin, info->return_type)) { - return false; - } - } else if (auto* interpolate = deco->As()) { + if (auto* interpolate = deco->As()) { if (!ValidateInterpolateDecoration(interpolate, info->return_type)) { return false; } - } else if (!deco->Is()) { + } else if (!deco->IsAnyOf() && + (IsValidationEnabled( + info->declaration->decorations(), + ast::DisabledValidation::kEntryPointParameter) && + IsValidationEnabled(info->declaration->decorations(), + ast::DisabledValidation:: + kIgnoreAtomicFunctionParameter))) { AddError("decoration is not valid for entry point return types", deco->source()); return false; @@ -1192,6 +1259,12 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func, } builtins.emplace(builtin->value()); + if (!ValidateBuiltinDecoration(builtin, ty, + /* is_input */ param_or_ret == + ParamOrRetType::kParameter)) { + return false; + } + } else if (auto* location = deco->As()) { if (pipeline_io_attribute) { AddError("multiple entry point IO attributes", deco->source()); @@ -1409,7 +1482,6 @@ bool Resolver::Function(ast::Function* func) { return false; } - // TODO(amaiorano): Validate parameter decorations for (auto* deco : param->decorations()) { Mark(deco); } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index e8e553f497..3306294977 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -273,7 +273,8 @@ class Resolver { bool ValidateAtomicUses(); bool ValidateAssignment(const ast::AssignmentStatement* a); bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, - const sem::Type* storage_type); + const sem::Type* storage_type, + const bool is_input = true); bool ValidateCallStatement(ast::CallStatement* stmt); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index fee3f575a5..32b01d2186 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -85,9 +85,8 @@ TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) { // fn f2(){ dst = wg; } // fn f1() { f2(); } // [[stage(fragment)]] - // fn f0() -> [[builtin(position)]] vec4 { + // fn f0() { // f1(); - // return dst; //} Global(Source{{1, 2}}, "wg", ty.vec4(), ast::StorageClass::kWorkgroup); @@ -97,10 +96,9 @@ TEST_F(ResolverValidationTest, WorkgroupMemoryUsedInFragmentStage) { Func(Source{{5, 6}}, "f2", ast::VariableList{}, ty.void_(), {stmt}); Func(Source{{7, 8}}, "f1", ast::VariableList{}, ty.void_(), {Ignore(Call("f2"))}); - Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.vec4(), - {Ignore(Call("f1")), Return(Expr("dst"))}, - ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, - ast::DecorationList{Builtin(ast::Builtin::kPosition)}); + Func(Source{{9, 10}}, "f0", ast::VariableList{}, ty.void_(), + {Ignore(Call("f1"))}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index ebadbb4b7e..01cc19986d 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -296,7 +296,7 @@ OpFunctionEnd TEST_F(BuilderTest, SampleIndex_SampleRateShadingCapability) { Func("main", {Param("sample_index", ty.u32(), {Builtin(ast::Builtin::kSampleIndex)})}, - ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); + ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); spirv::Builder& b = SanitizeAndBuild();