diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc index c6aa820e3c..4af121dbba 100644 --- a/src/resolver/builtins_validation_test.cc +++ b/src/resolver/builtins_validation_test.cc @@ -27,14 +27,13 @@ template using vec3 = builder::vec3; template using vec4 = builder::vec4; -template using f32 = builder::f32; using i32 = builder::i32; using u32 = builder::u32; class ResolverBuiltinsValidationTest : public resolver::TestHelper, public testing::Test {}; -namespace TypeTemp { +namespace StageTest { struct Params { builder::ast_type_func_ptr type; ast::Builtin builtin; @@ -218,7 +217,7 @@ TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) { "12:34 error: builtin(frag_depth) cannot be used in input of fragment " "pipeline stage\nnote: while analysing entry point fragShader"); } -} // namespace TypeTemp +} // namespace StageTest TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) { // struct MyInputs { diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 881d3800b4..0c48f04c74 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2266,6 +2266,9 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) { if (auto* arr_type = type->As()) { return ValidateArrayConstructor(type_ctor, arr_type); } + if (auto* struct_type = type->As()) { + return ValidateStructureConstructor(type_ctor, struct_type); + } } else if (auto* scalar_ctor = expr->As()) { Mark(scalar_ctor->literal()); auto* type = TypeOf(scalar_ctor->literal()); @@ -2280,6 +2283,36 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) { return true; } +bool Resolver::ValidateStructureConstructor( + const ast::TypeConstructorExpression* ctor, + const sem::Struct* struct_type) { + if (ctor->values().size() > 0) { + if (ctor->values().size() != struct_type->Members().size()) { + std::string fm = ctor->values().size() < struct_type->Members().size() + ? "few" + : "many"; + AddError("struct constructor has too " + fm + " inputs: expected " + + std::to_string(struct_type->Members().size()) + ", found " + + std::to_string(ctor->values().size()), + ctor->source()); + return false; + } + for (auto* member : struct_type->Members()) { + auto* value = ctor->values()[member->Index()]; + if (member->Type() != TypeOf(value)->UnwrapRef()) { + AddError( + "type in struct constructor does not match struct member type: " + "expected '" + + member->Type()->FriendlyName(builder_->Symbols()) + + "', found '" + TypeNameOf(value) + "'", + value->source()); + return false; + } + } + } + return true; +} + bool Resolver::ValidateArrayConstructor( const ast::TypeConstructorExpression* ctor, const sem::Array* array_type) { diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 3306294977..04b7326088 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -294,6 +294,8 @@ class Resolver { bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStorageTexture(const ast::StorageTexture* t); bool ValidateStructure(const sem::Struct* str); + bool ValidateStructureConstructor(const ast::TypeConstructorExpression* ctor, + const sem::Struct* struct_type); bool ValidateSwitch(const ast::SwitchStatement* s); bool ValidateVariable(const VariableInfo* info); bool ValidateVariableConstructor(const ast::Variable* var, diff --git a/src/resolver/struct_pipeline_stage_use_test.cc b/src/resolver/struct_pipeline_stage_use_test.cc index 1b2c18e863..7c0c0223c2 100644 --- a/src/resolver/struct_pipeline_stage_use_test.cc +++ b/src/resolver/struct_pipeline_stage_use_test.cc @@ -81,7 +81,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedAsVertexShaderReturnType) { auto* s = Structure( "S", {Member("a", ty.vec4(), {Builtin(ast::Builtin::kPosition)})}); - Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s), Expr(0.f)))}, + Func("main", {}, ty.Of(s), {Return(Construct(ty.Of(s)))}, {Stage(ast::PipelineStage::kVertex)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -141,8 +141,7 @@ TEST_F(ResolverPipelineStageUseTest, StructUsedMultipleStages) { "S", {Member("a", ty.vec4(), {Builtin(ast::Builtin::kPosition)})}); Func("vert_main", {Param("param", ty.Of(s))}, ty.Of(s), - {Return(Construct(ty.Of(s), Expr(0.f)))}, - {Stage(ast::PipelineStage::kVertex)}); + {Return(Construct(ty.Of(s)))}, {Stage(ast::PipelineStage::kVertex)}); Func("frag_main", {Param("param", ty.Of(s))}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 32b01d2186..eaa2d2b793 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -2335,6 +2335,176 @@ INSTANTIATE_TEST_SUITE_P(ResolverValidationTest, MatrixDimensions{3, 4}, MatrixDimensions{4, 4})); +namespace StructConstructor { +using builder::CreatePtrs; +using builder::CreatePtrsFor; +using builder::f32; +using builder::i32; +using builder::mat2x2; +using builder::mat3x3; +using builder::mat4x4; +using builder::u32; +using builder::vec2; +using builder::vec3; +using builder::vec4; + +constexpr CreatePtrs all_types[] = { + CreatePtrsFor(), // + CreatePtrsFor(), // + CreatePtrsFor(), // + CreatePtrsFor(), // + CreatePtrsFor>(), // + CreatePtrsFor>(), // + CreatePtrsFor>(), // + CreatePtrsFor>(), // + CreatePtrsFor>(), // + CreatePtrsFor>(), // + CreatePtrsFor>() // +}; + +auto number_of_members = testing::Values(2u, 32u, 64u); + +using StructConstructorInputsTest = + ResolverTestWithParam>; // number of struct members +TEST_P(StructConstructorInputsTest, TooFew) { + auto& param = GetParam(); + auto& str_params = std::get<0>(param); + uint32_t N = std::get<1>(param); + + ast::StructMemberList members; + ast::ExpressionList values; + for (uint32_t i = 0; i < N; i++) { + auto* struct_type = str_params.ast(*this); + members.push_back(Member("member_" + std::to_string(i), struct_type)); + if (i < N - 1) { + auto* ctor_value_expr = str_params.expr(*this, 0); + values.push_back(ctor_value_expr); + } + } + auto* s = Structure("s", members); + auto* tc = create(Source{{12, 34}}, ty.Of(s), + values); + WrapInFunction(tc); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: struct constructor has too few inputs: expected " + + std::to_string(N) + ", found " + std::to_string(N - 1)); +} + +TEST_P(StructConstructorInputsTest, TooMany) { + auto& param = GetParam(); + auto& str_params = std::get<0>(param); + uint32_t N = std::get<1>(param); + + ast::StructMemberList members; + ast::ExpressionList values; + for (uint32_t i = 0; i < N + 1; i++) { + if (i < N) { + auto* struct_type = str_params.ast(*this); + members.push_back(Member("member_" + std::to_string(i), struct_type)); + } + auto* ctor_value_expr = str_params.expr(*this, 0); + values.push_back(ctor_value_expr); + } + auto* s = Structure("s", members); + auto* tc = create(Source{{12, 34}}, ty.Of(s), + values); + WrapInFunction(tc); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: struct constructor has too many inputs: expected " + + std::to_string(N) + ", found " + std::to_string(N + 1)); +} + +INSTANTIATE_TEST_SUITE_P(ResolverValidationTest, + StructConstructorInputsTest, + testing::Combine(testing::ValuesIn(all_types), + number_of_members)); +using StructConstructorTypeTest = + ResolverTestWithParam>; // number of struct members +TEST_P(StructConstructorTypeTest, AllTypes) { + auto& param = GetParam(); + auto& str_params = std::get<0>(param); + auto& ctor_params = std::get<1>(param); + uint32_t N = std::get<2>(param); + + if (str_params.ast == ctor_params.ast) { + return; + } + + ast::StructMemberList members; + ast::ExpressionList values; + // make the last value of the constructor to have a different type + uint32_t constructor_value_with_different_type = N - 1; + for (uint32_t i = 0; i < N; i++) { + auto* struct_type = str_params.ast(*this); + members.push_back(Member("member_" + std::to_string(i), struct_type)); + auto* ctor_value_expr = (i == constructor_value_with_different_type) + ? ctor_params.expr(*this, 0) + : str_params.expr(*this, 0); + values.push_back(ctor_value_expr); + } + auto* s = Structure("s", members); + auto* tc = create(ty.Of(s), values); + WrapInFunction(tc); + + std::string found = FriendlyName(ctor_params.ast(*this)); + std::string expected = FriendlyName(str_params.ast(*this)); + std::stringstream err; + err << "error: type in struct constructor does not match struct member "; + err << "type: expected '" << expected << "', found '" << found << "'"; + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), err.str()); +} + +INSTANTIATE_TEST_SUITE_P(ResolverValidationTest, + StructConstructorTypeTest, + testing::Combine(testing::ValuesIn(all_types), + testing::ValuesIn(all_types), + number_of_members)); + +TEST_F(ResolverValidationTest, Expr_Constructor_Struct_Nested) { + auto* inner_m = Member("m", ty.i32()); + auto* inner_s = Structure("inner_s", {inner_m}); + + auto* m0 = Member("m", ty.i32()); + auto* m1 = Member("m", ty.Of(inner_s)); + auto* m2 = Member("m", ty.i32()); + auto* s = Structure("s", {m0, m1, m2}); + + auto* tc = create(Source{{12, 34}}, ty.Of(s), + ExprList(1, 1, 1)); + WrapInFunction(tc); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "error: type in struct constructor does not match struct member " + "type: expected 'inner_s', found 'i32'"); +} + +TEST_F(ResolverValidationTest, Expr_Constructor_Struct) { + auto* m = Member("m", ty.i32()); + auto* s = Structure("MyInputs", {m}); + auto* tc = create(Source{{12, 34}}, ty.Of(s), + ExprList()); + WrapInFunction(tc); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverValidationTest, Expr_Constructor_Struct_Empty) { + auto* str = Structure("S", { + Member("a", ty.i32()), + Member("b", ty.f32()), + Member("c", ty.vec3()), + }); + + WrapInFunction(Construct(ty.Of(str))); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} +} // namespace StructConstructor + } // namespace } // namespace resolver } // namespace tint diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index 189e8d5de2..f897004fec 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc @@ -500,7 +500,7 @@ struct FragmentInterface { [[stage(vertex)]] fn vert_main(in : VertexIn) -> VertexOut { - return VertexOut(in.i, in.u, in.vi, in.vu); + return VertexOut(in.i, in.u, in.vi, in.vu, vec4()); } [[stage(fragment)]] @@ -561,7 +561,7 @@ fn tint_symbol_11(tint_symbol_5 : VertexOut) { [[stage(vertex)]] fn vert_main() { let tint_symbol_4 : VertexIn = VertexIn(tint_symbol, tint_symbol_1, tint_symbol_2, tint_symbol_3); - tint_symbol_11(VertexOut(tint_symbol_4.i, tint_symbol_4.u, tint_symbol_4.vi, tint_symbol_4.vu)); + tint_symbol_11(VertexOut(tint_symbol_4.i, tint_symbol_4.u, tint_symbol_4.vi, tint_symbol_4.vu, vec4())); return; } diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 1499596466..57de37ce36 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -264,13 +264,10 @@ TEST_F(HlslGeneratorImplTest_Function, {}); Func("vert_main1", {}, ty.Of(vertex_output_struct), - {Return(Construct(ty.Of(vertex_output_struct), - Expr(Call("foo", Expr(0.5f)))))}, - {Stage(ast::PipelineStage::kVertex)}); + {Return(Call("foo", Expr(0.5f)))}, {Stage(ast::PipelineStage::kVertex)}); Func("vert_main2", {}, ty.Of(vertex_output_struct), - {Return(Construct(ty.Of(vertex_output_struct), - Expr(Call("foo", Expr(0.25f)))))}, + {Return(Call("foo", Expr(0.25f)))}, {Stage(ast::PipelineStage::kVertex)}); GeneratorImpl& gen = SanitizeAndBuild(); @@ -290,7 +287,7 @@ struct tint_symbol { }; tint_symbol vert_main1() { - const VertexOutput tint_symbol_1 = {foo(0.5f)}; + const VertexOutput tint_symbol_1 = foo(0.5f); const tint_symbol tint_symbol_5 = {tint_symbol_1.pos}; return tint_symbol_5; } @@ -300,7 +297,7 @@ struct tint_symbol_2 { }; tint_symbol_2 vert_main2() { - const VertexOutput tint_symbol_3 = {foo(0.25f)}; + const VertexOutput tint_symbol_3 = foo(0.25f); const tint_symbol_2 tint_symbol_6 = {tint_symbol_3.pos}; return tint_symbol_6; } diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc index 08c130eb86..f1ba7c1642 100644 --- a/src/writer/spirv/builder_constructor_expression_test.cc +++ b/src/writer/spirv/builder_constructor_expression_test.cc @@ -1793,9 +1793,9 @@ TEST_F(SpvBuilderConstructorTest, }); Global("a", ty.f32(), ast::StorageClass::kPrivate); - Global("b", ty.f32(), ast::StorageClass::kPrivate); + Global("b", ty.vec3(), ast::StorageClass::kPrivate); - auto* t = Construct(ty.Of(s), 2.f, "a", 2.f); + auto* t = Construct(ty.Of(s), "a", "b"); WrapInFunction(t); spirv::Builder& b = Build();