From 4d697515729dcfdf689fd9792d979e73499b7efd Mon Sep 17 00:00:00 2001 From: Sarah Date: Mon, 21 Jun 2021 17:08:05 +0000 Subject: [PATCH] validation: matrix element type must be 'f32' Bug: tint:784 Change-Id: Iafb1d3e16beb489d588b7af6aac18f7cee26154b Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/54900 Auto-Submit: Sarah Mashayekhi Kokoro: Kokoro Reviewed-by: Ben Clayton --- src/resolver/call_test.cc | 6 ---- src/resolver/entry_point_validation_test.cc | 9 ----- src/resolver/inferred_type_test.cc | 4 --- src/resolver/resolver.cc | 25 +++++++++++-- src/resolver/resolver.h | 1 + src/resolver/resolver_test.cc | 12 ++----- src/resolver/struct_layout_test.cc | 20 +++++------ .../type_constructor_validation_test.cc | 8 ----- src/resolver/validation_test.cc | 25 ++++++------- .../generator_impl_member_accessor_test.cc | 36 +++++++++---------- 10 files changed, 65 insertions(+), 81 deletions(-) diff --git a/src/resolver/call_test.cc b/src/resolver/call_test.cc index e7e4d245a0..3c96d10629 100644 --- a/src/resolver/call_test.cc +++ b/src/resolver/call_test.cc @@ -77,14 +77,8 @@ static constexpr Params all_param_types[] = { ParamsFor>(), // ParamsFor>(), // ParamsFor>(), // - ParamsFor>(), // - ParamsFor>(), // ParamsFor>(), // - ParamsFor>(), // - ParamsFor>(), // ParamsFor>(), // - ParamsFor>(), // - ParamsFor>(), // ParamsFor>() // }; diff --git a/src/resolver/entry_point_validation_test.cc b/src/resolver/entry_point_validation_test.cc index cd99f6fc9e..881a44db7e 100644 --- a/src/resolver/entry_point_validation_test.cc +++ b/src/resolver/entry_point_validation_test.cc @@ -564,17 +564,8 @@ static constexpr Params cases[] = { ParamsFor>(true), // ParamsFor>(true), // ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // - ParamsFor>(false), // ParamsFor>(true), // ParamsFor>(true), // ParamsFor>(true), // diff --git a/src/resolver/inferred_type_test.cc b/src/resolver/inferred_type_test.cc index bca018e2c9..789964bc93 100644 --- a/src/resolver/inferred_type_test.cc +++ b/src/resolver/inferred_type_test.cc @@ -65,8 +65,6 @@ Params all_cases[] = { ParamsFor>(), // ParamsFor>(), // ParamsFor>(), // - ParamsFor>(), // - ParamsFor>(), // ParamsFor>(), // ParamsFor>(), // ParamsFor>(), // @@ -76,8 +74,6 @@ Params all_cases[] = { ParamsFor>>(), // ParamsFor>>(), // ParamsFor>>(), // - ParamsFor>>(), // - ParamsFor>>(), // ParamsFor>>(), // }; diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index a9f0e8db81..9b7234f862 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -293,9 +293,15 @@ sem::Type* Resolver::Type(const ast::Type* ty) { } if (auto* t = ty->As()) { if (auto* el = Type(t->type())) { - auto* column_type = builder_->create( - const_cast(el), t->rows()); - return builder_->create(column_type, t->columns()); + if (auto* column_type = builder_->create( + const_cast(el), t->rows())) { + if (auto* matrix_type = + builder_->create(column_type, t->columns())) { + if (ValidateMatrix(matrix_type, t->source())) { + return matrix_type; + } + } + } } return nullptr; } @@ -2139,6 +2145,15 @@ bool Resolver::ValidateVectorConstructor( return true; } +bool Resolver::ValidateMatrix(const sem::Matrix* matrix_type, + const Source& source) { + if (!matrix_type->is_float_matrix()) { + diagnostics_.add_error("matrix element type must be 'f32'", source); + return false; + } + return true; +} // namespace resolver + bool Resolver::ValidateMatrixConstructor( const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type) { @@ -2148,6 +2163,10 @@ bool Resolver::ValidateMatrixConstructor( return true; } + if (!ValidateMatrix(matrix_type, ctor->source())) { + return false; + } + auto* elem_type = matrix_type->type(); if (matrix_type->columns() != values.size()) { const Source& values_start = values[0]->source(); diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 07463d3583..d92b4b136e 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -272,6 +272,7 @@ class Resolver { bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateGlobalVariable(const VariableInfo* var); + bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type); bool ValidateParameter(const VariableInfo* info); diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index facdd7fd47..c5e7176597 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -496,7 +496,7 @@ TEST_F(ResolverTest, ArrayAccessor_Matrix_Dynamic_Ref) { } TEST_F(ResolverTest, ArrayAccessor_Matrix_BothDimensions_Dynamic_Ref) { - Global("my_var", ty.mat4x4(), ast::StorageClass::kOutput); + Global("my_var", ty.mat4x4(), ast::StorageClass::kOutput); auto* idx = Var("idx", ty.u32(), Expr(3u)); auto* idy = Var("idy", ty.u32(), Expr(2u)); auto* acc = IndexAccessor(IndexAccessor("my_var", idx), idy); @@ -517,7 +517,7 @@ TEST_F(ResolverTest, ArrayAccessor_Matrix_Dynamic) { } TEST_F(ResolverTest, ArrayAccessor_Matrix_XDimension_Dynamic) { - GlobalConst("my_var", ty.mat4x4(), Construct(ty.mat4x4())); + GlobalConst("my_var", ty.mat4x4(), Construct(ty.mat4x4())); auto* idx = Var("idx", ty.u32(), Expr(3u)); auto* acc = IndexAccessor("my_var", Expr(Source{{12, 34}}, idx)); WrapInFunction(Decl(idx), acc); @@ -528,7 +528,7 @@ TEST_F(ResolverTest, ArrayAccessor_Matrix_XDimension_Dynamic) { } TEST_F(ResolverTest, ArrayAccessor_Matrix_BothDimension_Dynamic) { - GlobalConst("my_var", ty.mat4x4(), Construct(ty.mat4x4())); + GlobalConst("my_var", ty.mat4x4(), Construct(ty.mat4x4())); auto* idx = Var("idy", ty.u32(), Expr(2u)); auto* acc = IndexAccessor(IndexAccessor("my_var", Expr(Source{{12, 34}}, idx)), 1); @@ -1486,14 +1486,8 @@ static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = { DataType>::AST, // DataType>::AST, // DataType>::AST, // - DataType>::AST, // - DataType>::AST, // DataType>::AST, // - DataType>::AST, // - DataType>::AST, // DataType>::AST, // - DataType>::AST, // - DataType>::AST, // DataType>::AST // }; diff --git a/src/resolver/struct_layout_test.cc b/src/resolver/struct_layout_test.cc index f516412d60..709724377f 100644 --- a/src/resolver/struct_layout_test.cc +++ b/src/resolver/struct_layout_test.cc @@ -242,15 +242,15 @@ TEST_F(ResolverStructLayoutTest, Vector) { TEST_F(ResolverStructLayoutTest, Matrix) { auto* s = Structure("S", { - Member("a", ty.mat2x2()), - Member("b", ty.mat2x3()), - Member("c", ty.mat2x4()), - Member("d", ty.mat3x2()), - Member("e", ty.mat3x3()), - Member("f", ty.mat3x4()), - Member("g", ty.mat4x2()), - Member("h", ty.mat4x3()), - Member("i", ty.mat4x4()), + Member("a", ty.mat2x2()), + Member("b", ty.mat2x3()), + Member("c", ty.mat2x4()), + Member("d", ty.mat3x2()), + Member("e", ty.mat3x3()), + Member("f", ty.mat3x4()), + Member("g", ty.mat4x2()), + Member("h", ty.mat4x3()), + Member("i", ty.mat4x4()), }); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -292,7 +292,7 @@ TEST_F(ResolverStructLayoutTest, Matrix) { TEST_F(ResolverStructLayoutTest, NestedStruct) { auto* inner = Structure("Inner", { - Member("a", ty.mat3x3()), + Member("a", ty.mat3x3()), }); auto* s = Structure("S", { Member("a", ty.i32()), diff --git a/src/resolver/type_constructor_validation_test.cc b/src/resolver/type_constructor_validation_test.cc index bd9cf20138..44066dd8fb 100644 --- a/src/resolver/type_constructor_validation_test.cc +++ b/src/resolver/type_constructor_validation_test.cc @@ -110,8 +110,6 @@ static constexpr Params from_constructor_expression_cases[] = { ParamsFor>(), ParamsFor>(), ParamsFor>(), - ParamsFor>(), - ParamsFor>(), ParamsFor>(), ParamsFor>(), ParamsFor>(), @@ -120,8 +118,6 @@ static constexpr Params from_constructor_expression_cases[] = { ParamsFor>>(), ParamsFor>>(), ParamsFor>>(), - ParamsFor>>(), - ParamsFor>>(), ParamsFor>>(), }; INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, @@ -208,8 +204,6 @@ static constexpr Params from_call_expression_cases[] = { ParamsFor>(), ParamsFor>(), ParamsFor>(), - ParamsFor>(), - ParamsFor>(), ParamsFor>(), ParamsFor>(), ParamsFor>(), @@ -218,8 +212,6 @@ static constexpr Params from_call_expression_cases[] = { ParamsFor>>(), ParamsFor>>(), ParamsFor>>(), - ParamsFor>>(), - ParamsFor>>(), ParamsFor>>(), }; INSTANTIATE_TEST_SUITE_P(ResolverTypeConstructorValidationTest, diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 689ef588f7..3af33e0e49 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -876,20 +876,6 @@ TEST_F(ResolverValidationTest, "expected 'array', found 'array'"); } -TEST_F(ResolverValidationTest, - Expr_Constructor_ArrayOfMatrix_SubElemTypeMismatch) { - // array, 2>(mat2x2(), mat2x2()); - auto* e0 = mat2x2(); - SetSource(Source::Location({12, 34})); - auto* e1 = mat2x2(); - auto* t = Construct(ty.array(ty.mat2x2(), 2), e0, e1); - WrapInFunction(t); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: type in array constructor does not match array type: " - "expected 'mat2x2', found 'mat2x2'"); -} TEST_F(ResolverValidationTest, Expr_Constructor_Array_TooFewElements) { // array(1, 2, 3); SetSource(Source::Location({12, 34})); @@ -1987,6 +1973,17 @@ std::string VecStr(uint32_t dimensions, std::string subtype = "f32") { using MatrixConstructorTest = ResolverTestWithParam; +TEST_F(MatrixConstructorTest, Expr_Constructor_Matrix_NotF32) { + // m2x2() + SetSource(Source::Location({12, 34})); + auto* tc = mat2x2( + create(ty.mat2x2(), ExprList())); + WrapInFunction(tc); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'"); +} + TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) { // matNxM(vecM(), ...); with N - 1 arguments diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index 2d68500804..69f8142d7e 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc @@ -202,32 +202,32 @@ INSTANTIATE_TEST_SUITE_P( TypeCase{ty_vec4, "asfloat(data.Load4(16u))"}, TypeCase{ty_vec4, "asint(data.Load4(16u))"}, TypeCase{ - ty_mat2x2, - R"(return uint2x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)));)"}, + ty_mat2x2, + R"(return float2x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))));)"}, TypeCase{ ty_mat2x3, R"(return float2x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))));)"}, TypeCase{ - ty_mat2x4, - R"(return int2x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))));)"}, + ty_mat2x4, + R"(return float2x4(asfloat(buffer.Load4((offset + 0u))), asfloat(buffer.Load4((offset + 16u))));)"}, TypeCase{ - ty_mat3x2, - R"(return uint3x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)), buffer.Load2((offset + 16u)));)"}, + ty_mat3x2, + R"(return float3x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))), asfloat(buffer.Load2((offset + 16u))));)"}, TypeCase{ ty_mat3x3, R"(return float3x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))), asfloat(buffer.Load3((offset + 32u))));)"}, TypeCase{ - ty_mat3x4, - R"(return int3x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))), asint(buffer.Load4((offset + 32u))));)"}, + ty_mat3x4, + R"(return float3x4(asfloat(buffer.Load4((offset + 0u))), asfloat(buffer.Load4((offset + 16u))), asfloat(buffer.Load4((offset + 32u))));)"}, TypeCase{ - ty_mat4x2, - R"(return uint4x2(buffer.Load2((offset + 0u)), buffer.Load2((offset + 8u)), buffer.Load2((offset + 16u)), buffer.Load2((offset + 24u)));)"}, + ty_mat4x2, + R"(return float4x2(asfloat(buffer.Load2((offset + 0u))), asfloat(buffer.Load2((offset + 8u))), asfloat(buffer.Load2((offset + 16u))), asfloat(buffer.Load2((offset + 24u))));)"}, TypeCase{ ty_mat4x3, R"(return float4x3(asfloat(buffer.Load3((offset + 0u))), asfloat(buffer.Load3((offset + 16u))), asfloat(buffer.Load3((offset + 32u))), asfloat(buffer.Load3((offset + 48u))));)"}, TypeCase{ - ty_mat4x4, - R"(return int4x4(asint(buffer.Load4((offset + 0u))), asint(buffer.Load4((offset + 16u))), asint(buffer.Load4((offset + 32u))), asint(buffer.Load4((offset + 48u))));)"})); + ty_mat4x4, + R"(return float4x4(asfloat(buffer.Load4((offset + 0u))), asfloat(buffer.Load4((offset + 16u))), asfloat(buffer.Load4((offset + 32u))), asfloat(buffer.Load4((offset + 48u))));)"})); using HlslGeneratorImplTest_MemberAccessor_StorageBufferStore = HlslGeneratorImplTest_MemberAccessorWithParam; @@ -273,7 +273,7 @@ INSTANTIATE_TEST_SUITE_P( TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, TypeCase{ty_vec4, "data.Store4(16u, asuint(value))"}, - TypeCase{ty_mat2x2, R"({ + TypeCase{ty_mat2x2, R"({ buffer.Store2((offset + 0u), asuint(value[0u])); buffer.Store2((offset + 8u), asuint(value[1u])); })"}, @@ -281,11 +281,11 @@ INSTANTIATE_TEST_SUITE_P( buffer.Store3((offset + 0u), asuint(value[0u])); buffer.Store3((offset + 16u), asuint(value[1u])); })"}, - TypeCase{ty_mat2x4, R"({ + TypeCase{ty_mat2x4, R"({ buffer.Store4((offset + 0u), asuint(value[0u])); buffer.Store4((offset + 16u), asuint(value[1u])); })"}, - TypeCase{ty_mat3x2, R"({ + TypeCase{ty_mat3x2, R"({ buffer.Store2((offset + 0u), asuint(value[0u])); buffer.Store2((offset + 8u), asuint(value[1u])); buffer.Store2((offset + 16u), asuint(value[2u])); @@ -295,12 +295,12 @@ INSTANTIATE_TEST_SUITE_P( buffer.Store3((offset + 16u), asuint(value[1u])); buffer.Store3((offset + 32u), asuint(value[2u])); })"}, - TypeCase{ty_mat3x4, R"({ + TypeCase{ty_mat3x4, R"({ buffer.Store4((offset + 0u), asuint(value[0u])); buffer.Store4((offset + 16u), asuint(value[1u])); buffer.Store4((offset + 32u), asuint(value[2u])); })"}, - TypeCase{ty_mat4x2, R"({ + TypeCase{ty_mat4x2, R"({ buffer.Store2((offset + 0u), asuint(value[0u])); buffer.Store2((offset + 8u), asuint(value[1u])); buffer.Store2((offset + 16u), asuint(value[2u])); @@ -312,7 +312,7 @@ INSTANTIATE_TEST_SUITE_P( buffer.Store3((offset + 32u), asuint(value[2u])); buffer.Store3((offset + 48u), asuint(value[3u])); })"}, - TypeCase{ty_mat4x4, R"({ + TypeCase{ty_mat4x4, R"({ buffer.Store4((offset + 0u), asuint(value[0u])); buffer.Store4((offset + 16u), asuint(value[1u])); buffer.Store4((offset + 32u), asuint(value[2u]));