resolver: Validate vector types

Fixed: tint:953
Change-Id: I3742680e49894a93db41219e512796ba9bdf036a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56778
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
Ben Clayton 2021-07-05 20:21:35 +00:00 committed by Tint LUCI CQ
parent b4ff73e250
commit ffe7978dbf
6 changed files with 189 additions and 35 deletions

View File

@ -324,8 +324,12 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::Vector>()) {
if (auto* el = Type(t->type())) {
return builder_->create<sem::Vector>(const_cast<sem::Type*>(el),
t->size());
if (auto* vector = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->size())) {
if (ValidateVector(vector, t->source())) {
return vector;
}
}
}
return nullptr;
}
@ -333,10 +337,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
if (auto* el = Type(t->type())) {
if (auto* column_type = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->rows())) {
if (auto* matrix_type =
if (auto* matrix =
builder_->create<sem::Matrix>(column_type, t->columns())) {
if (ValidateMatrix(matrix_type, t->source())) {
return matrix_type;
if (ValidateMatrix(matrix, t->source())) {
return matrix;
}
}
}
@ -2300,14 +2304,22 @@ bool Resolver::ValidateVectorConstructor(
return true;
}
bool Resolver::ValidateMatrix(const sem::Matrix* matrix_type,
const Source& source) {
if (!matrix_type->is_float_matrix()) {
bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
if (!ty->type()->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
source);
return false;
}
return true;
}
bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source);
return false;
}
return true;
} // namespace resolver
}
bool Resolver::ValidateMatrixConstructor(
const ast::TypeConstructorExpression* ctor,

View File

@ -280,7 +280,7 @@ class Resolver {
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);
bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source);
bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Matrix* matrix_type);
bool ValidateFunctionParameter(const ast::Function* func,
@ -300,6 +300,7 @@ class Resolver {
const std::string& type_name,
const sem::Type* rhs_type,
const std::string& rhs_type_name);
bool ValidateVector(const sem::Vector* ty, const Source& source);
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
const sem::Vector* vec_type);
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,

View File

@ -1356,29 +1356,19 @@ TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
// vec4<f32> foo
// }
// struct A {
// vec3<struct b> mem
// array<b, 3> mem
// }
// var c : A
// c.mem[0].foo.yx
// -> vec2<f32>
//
// MemberAccessor{
// MemberAccessor{
// ArrayAccessor{
// MemberAccessor{
// Identifier{c}
// Identifier{mem}
// }
// ScalarConstructor{0}
// }
// Identifier{foo}
// }
// Identifier{yx}
// fn f() {
// c.mem[0].foo
// }
//
auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
auto* stA = Structure("A", {Member("mem", ty.vec(ty.Of(stB), 3))});
auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
Global("c", ty.Of(stA), ast::StorageClass::kPrivate);
auto* mem = MemberAccessor(

View File

@ -156,6 +156,9 @@ using mat3x3 = mat<3, 3, T>;
template <typename T>
using mat4x4 = mat<4, 4, T>;
template <int N, typename T>
struct array {};
template <typename TO, int ID = 0>
struct alias {};
@ -384,6 +387,43 @@ struct DataType<alias<T, ID>> {
}
};
/// Helper for building array types and expressions
template <int N, typename T>
struct DataType<array<N, T>> {
/// true as arrays are a composite type
static constexpr bool is_composite = true;
/// @param b the ProgramBuilder
/// @return a new AST array type
static inline ast::Type* AST(ProgramBuilder& b) {
return b.ty.array(DataType<T>::AST(b), N);
}
/// @param b the ProgramBuilder
/// @return the semantic array type
static inline sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Array>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the array will be initialized
/// with
/// @return a new AST array value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element will be initialized with
/// @return the list of expressions that are used to construct the array
static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
int elem_value) {
ast::ExpressionList args;
for (int i = 0; i < N; i++) {
args.emplace_back(DataType<T>::Expr(b, elem_value));
}
return args;
}
};
/// Struct of all creation pointer types
struct CreatePtrs {
/// ast node type create function

View File

@ -42,6 +42,8 @@ template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <int N, typename T>
using array = builder::array<N, T>;
template <typename T>
using alias = builder::alias<T>;
template <typename T>
@ -751,6 +753,126 @@ TEST_F(StorageTextureAccessTest, WriteOnlyAccess_Pass) {
} // namespace StorageTextureTests
namespace MatrixTests {
struct Params {
uint32_t columns;
uint32_t rows;
builder::ast_type_func_ptr elem_ty;
};
template <typename T>
constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
return Params{columns, rows, DataType<T>::AST};
}
using ValidMatrixTypes = ResolverTestWithParam<Params>;
TEST_P(ValidMatrixTypes, Okay) {
// var a : matNxM<EL_TY>;
auto& params = GetParam();
Global("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
ast::StorageClass::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
ValidMatrixTypes,
testing::Values(ParamsFor<f32>(2, 2),
ParamsFor<f32>(2, 3),
ParamsFor<f32>(2, 4),
ParamsFor<f32>(3, 2),
ParamsFor<f32>(3, 3),
ParamsFor<f32>(3, 4),
ParamsFor<f32>(4, 2),
ParamsFor<f32>(4, 3),
ParamsFor<f32>(4, 4),
ParamsFor<alias<f32>>(4, 2),
ParamsFor<alias<f32>>(4, 3),
ParamsFor<alias<f32>>(4, 4)));
using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
// var a : matNxM<EL_TY>;
auto& params = GetParam();
Global("a",
ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns,
params.rows),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
InvalidMatrixElementTypes,
testing::Values(ParamsFor<bool>(4, 2),
ParamsFor<i32>(4, 3),
ParamsFor<u32>(4, 4),
ParamsFor<vec2<f32>>(2, 2),
ParamsFor<vec3<i32>>(2, 3),
ParamsFor<vec4<u32>>(2, 4),
ParamsFor<mat2x2<f32>>(3, 2),
ParamsFor<mat3x3<f32>>(3, 3),
ParamsFor<mat4x4<f32>>(3, 4),
ParamsFor<array<2, f32>>(4, 2)));
} // namespace MatrixTests
namespace VectorTests {
struct Params {
uint32_t width;
builder::ast_type_func_ptr elem_ty;
};
template <typename T>
constexpr Params ParamsFor(uint32_t width) {
return Params{width, DataType<T>::AST};
}
using ValidVectorTypes = ResolverTestWithParam<Params>;
TEST_P(ValidVectorTypes, Okay) {
// var a : vecN<EL_TY>;
auto& params = GetParam();
Global("a", ty.vec(params.elem_ty(*this), params.width),
ast::StorageClass::kPrivate);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
ValidVectorTypes,
testing::Values(ParamsFor<bool>(2),
ParamsFor<f32>(2),
ParamsFor<i32>(2),
ParamsFor<u32>(2),
ParamsFor<bool>(3),
ParamsFor<f32>(3),
ParamsFor<i32>(3),
ParamsFor<u32>(3),
ParamsFor<bool>(4),
ParamsFor<f32>(4),
ParamsFor<i32>(4),
ParamsFor<u32>(4),
ParamsFor<alias<bool>>(4),
ParamsFor<alias<f32>>(4),
ParamsFor<alias<i32>>(4),
ParamsFor<alias<u32>>(4)));
using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
TEST_P(InvalidVectorElementTypes, InvalidElementType) {
// var a : vecN<EL_TY>;
auto& params = GetParam();
Global("a", ty.vec(Source{{12, 34}}, params.elem_ty(*this), params.width),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: vector element type must be 'bool', 'f32', 'i32' or 'u32'");
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
InvalidVectorElementTypes,
testing::Values(ParamsFor<vec2<f32>>(2),
ParamsFor<vec3<i32>>(2),
ParamsFor<vec4<u32>>(2),
ParamsFor<mat2x2<f32>>(2),
ParamsFor<mat3x3<f32>>(2),
ParamsFor<mat4x4<f32>>(2),
ParamsFor<array<2, f32>>(2)));
} // namespace VectorTests
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -2007,17 +2007,6 @@ std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
TEST_F(MatrixConstructorTest, Expr_Constructor_Matrix_NotF32) {
// m2x2<i32>()
SetSource(Source::Location({12, 34}));
auto* tc = mat2x2<i32>(
create<ast::TypeConstructorExpression>(ty.mat2x2<i32>(), ExprList()));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments