resolver: Validate vector types
Fixed: tint:953 Change-Id: I3742680e49894a93db41219e512796ba9bdf036a Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/56778 Auto-Submit: Ben Clayton <bclayton@google.com> Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@chromium.org> Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
parent
b4ff73e250
commit
ffe7978dbf
|
@ -324,8 +324,12 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
|
|||
}
|
||||
if (auto* t = ty->As<ast::Vector>()) {
|
||||
if (auto* el = Type(t->type())) {
|
||||
return builder_->create<sem::Vector>(const_cast<sem::Type*>(el),
|
||||
t->size());
|
||||
if (auto* vector = builder_->create<sem::Vector>(
|
||||
const_cast<sem::Type*>(el), t->size())) {
|
||||
if (ValidateVector(vector, t->source())) {
|
||||
return vector;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -333,10 +337,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
|
|||
if (auto* el = Type(t->type())) {
|
||||
if (auto* column_type = builder_->create<sem::Vector>(
|
||||
const_cast<sem::Type*>(el), t->rows())) {
|
||||
if (auto* matrix_type =
|
||||
if (auto* matrix =
|
||||
builder_->create<sem::Matrix>(column_type, t->columns())) {
|
||||
if (ValidateMatrix(matrix_type, t->source())) {
|
||||
return matrix_type;
|
||||
if (ValidateMatrix(matrix, t->source())) {
|
||||
return matrix;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2300,14 +2304,22 @@ bool Resolver::ValidateVectorConstructor(
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ValidateMatrix(const sem::Matrix* matrix_type,
|
||||
const Source& source) {
|
||||
if (!matrix_type->is_float_matrix()) {
|
||||
bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
|
||||
if (!ty->type()->is_scalar()) {
|
||||
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
|
||||
source);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
|
||||
if (!ty->is_float_matrix()) {
|
||||
AddError("matrix element type must be 'f32'", source);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} // namespace resolver
|
||||
}
|
||||
|
||||
bool Resolver::ValidateMatrixConstructor(
|
||||
const ast::TypeConstructorExpression* ctor,
|
||||
|
|
|
@ -280,7 +280,7 @@ class Resolver {
|
|||
bool ValidateGlobalVariable(const VariableInfo* var);
|
||||
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
|
||||
const sem::Type* storage_type);
|
||||
bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source);
|
||||
bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
|
||||
bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor,
|
||||
const sem::Matrix* matrix_type);
|
||||
bool ValidateFunctionParameter(const ast::Function* func,
|
||||
|
@ -300,6 +300,7 @@ class Resolver {
|
|||
const std::string& type_name,
|
||||
const sem::Type* rhs_type,
|
||||
const std::string& rhs_type_name);
|
||||
bool ValidateVector(const sem::Vector* ty, const Source& source);
|
||||
bool ValidateVectorConstructor(const ast::TypeConstructorExpression* ctor,
|
||||
const sem::Vector* vec_type);
|
||||
bool ValidateScalarConstructor(const ast::TypeConstructorExpression* ctor,
|
||||
|
|
|
@ -1356,29 +1356,19 @@ TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
|
|||
// vec4<f32> foo
|
||||
// }
|
||||
// struct A {
|
||||
// vec3<struct b> mem
|
||||
// array<b, 3> mem
|
||||
// }
|
||||
// var c : A
|
||||
// c.mem[0].foo.yx
|
||||
// -> vec2<f32>
|
||||
//
|
||||
// MemberAccessor{
|
||||
// MemberAccessor{
|
||||
// ArrayAccessor{
|
||||
// MemberAccessor{
|
||||
// Identifier{c}
|
||||
// Identifier{mem}
|
||||
// }
|
||||
// ScalarConstructor{0}
|
||||
// }
|
||||
// Identifier{foo}
|
||||
// }
|
||||
// Identifier{yx}
|
||||
// fn f() {
|
||||
// c.mem[0].foo
|
||||
// }
|
||||
//
|
||||
|
||||
auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
|
||||
auto* stA = Structure("A", {Member("mem", ty.vec(ty.Of(stB), 3))});
|
||||
auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
|
||||
Global("c", ty.Of(stA), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* mem = MemberAccessor(
|
||||
|
|
|
@ -156,6 +156,9 @@ using mat3x3 = mat<3, 3, T>;
|
|||
template <typename T>
|
||||
using mat4x4 = mat<4, 4, T>;
|
||||
|
||||
template <int N, typename T>
|
||||
struct array {};
|
||||
|
||||
template <typename TO, int ID = 0>
|
||||
struct alias {};
|
||||
|
||||
|
@ -384,6 +387,43 @@ struct DataType<alias<T, ID>> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Helper for building array types and expressions
|
||||
template <int N, typename T>
|
||||
struct DataType<array<N, T>> {
|
||||
/// true as arrays are a composite type
|
||||
static constexpr bool is_composite = true;
|
||||
|
||||
/// @param b the ProgramBuilder
|
||||
/// @return a new AST array type
|
||||
static inline ast::Type* AST(ProgramBuilder& b) {
|
||||
return b.ty.array(DataType<T>::AST(b), N);
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @return the semantic array type
|
||||
static inline sem::Type* Sem(ProgramBuilder& b) {
|
||||
return b.create<sem::Array>(DataType<T>::Sem(b), N);
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param elem_value the value each element in the array will be initialized
|
||||
/// with
|
||||
/// @return a new AST array value expression
|
||||
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
|
||||
return b.Construct(AST(b), ExprArgs(b, elem_value));
|
||||
}
|
||||
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param elem_value the value each element will be initialized with
|
||||
/// @return the list of expressions that are used to construct the array
|
||||
static inline ast::ExpressionList ExprArgs(ProgramBuilder& b,
|
||||
int elem_value) {
|
||||
ast::ExpressionList args;
|
||||
for (int i = 0; i < N; i++) {
|
||||
args.emplace_back(DataType<T>::Expr(b, elem_value));
|
||||
}
|
||||
return args;
|
||||
}
|
||||
};
|
||||
|
||||
/// Struct of all creation pointer types
|
||||
struct CreatePtrs {
|
||||
/// ast node type create function
|
||||
|
|
|
@ -42,6 +42,8 @@ template <typename T>
|
|||
using mat3x3 = builder::mat3x3<T>;
|
||||
template <typename T>
|
||||
using mat4x4 = builder::mat4x4<T>;
|
||||
template <int N, typename T>
|
||||
using array = builder::array<N, T>;
|
||||
template <typename T>
|
||||
using alias = builder::alias<T>;
|
||||
template <typename T>
|
||||
|
@ -751,6 +753,126 @@ TEST_F(StorageTextureAccessTest, WriteOnlyAccess_Pass) {
|
|||
|
||||
} // namespace StorageTextureTests
|
||||
|
||||
namespace MatrixTests {
|
||||
struct Params {
|
||||
uint32_t columns;
|
||||
uint32_t rows;
|
||||
builder::ast_type_func_ptr elem_ty;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr Params ParamsFor(uint32_t columns, uint32_t rows) {
|
||||
return Params{columns, rows, DataType<T>::AST};
|
||||
}
|
||||
|
||||
using ValidMatrixTypes = ResolverTestWithParam<Params>;
|
||||
TEST_P(ValidMatrixTypes, Okay) {
|
||||
// var a : matNxM<EL_TY>;
|
||||
auto& params = GetParam();
|
||||
Global("a", ty.mat(params.elem_ty(*this), params.columns, params.rows),
|
||||
ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
|
||||
ValidMatrixTypes,
|
||||
testing::Values(ParamsFor<f32>(2, 2),
|
||||
ParamsFor<f32>(2, 3),
|
||||
ParamsFor<f32>(2, 4),
|
||||
ParamsFor<f32>(3, 2),
|
||||
ParamsFor<f32>(3, 3),
|
||||
ParamsFor<f32>(3, 4),
|
||||
ParamsFor<f32>(4, 2),
|
||||
ParamsFor<f32>(4, 3),
|
||||
ParamsFor<f32>(4, 4),
|
||||
ParamsFor<alias<f32>>(4, 2),
|
||||
ParamsFor<alias<f32>>(4, 3),
|
||||
ParamsFor<alias<f32>>(4, 4)));
|
||||
|
||||
using InvalidMatrixElementTypes = ResolverTestWithParam<Params>;
|
||||
TEST_P(InvalidMatrixElementTypes, InvalidElementType) {
|
||||
// var a : matNxM<EL_TY>;
|
||||
auto& params = GetParam();
|
||||
Global("a",
|
||||
ty.mat(Source{{12, 34}}, params.elem_ty(*this), params.columns,
|
||||
params.rows),
|
||||
ast::StorageClass::kPrivate);
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
|
||||
InvalidMatrixElementTypes,
|
||||
testing::Values(ParamsFor<bool>(4, 2),
|
||||
ParamsFor<i32>(4, 3),
|
||||
ParamsFor<u32>(4, 4),
|
||||
ParamsFor<vec2<f32>>(2, 2),
|
||||
ParamsFor<vec3<i32>>(2, 3),
|
||||
ParamsFor<vec4<u32>>(2, 4),
|
||||
ParamsFor<mat2x2<f32>>(3, 2),
|
||||
ParamsFor<mat3x3<f32>>(3, 3),
|
||||
ParamsFor<mat4x4<f32>>(3, 4),
|
||||
ParamsFor<array<2, f32>>(4, 2)));
|
||||
} // namespace MatrixTests
|
||||
|
||||
namespace VectorTests {
|
||||
struct Params {
|
||||
uint32_t width;
|
||||
builder::ast_type_func_ptr elem_ty;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr Params ParamsFor(uint32_t width) {
|
||||
return Params{width, DataType<T>::AST};
|
||||
}
|
||||
|
||||
using ValidVectorTypes = ResolverTestWithParam<Params>;
|
||||
TEST_P(ValidVectorTypes, Okay) {
|
||||
// var a : vecN<EL_TY>;
|
||||
auto& params = GetParam();
|
||||
Global("a", ty.vec(params.elem_ty(*this), params.width),
|
||||
ast::StorageClass::kPrivate);
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
|
||||
ValidVectorTypes,
|
||||
testing::Values(ParamsFor<bool>(2),
|
||||
ParamsFor<f32>(2),
|
||||
ParamsFor<i32>(2),
|
||||
ParamsFor<u32>(2),
|
||||
ParamsFor<bool>(3),
|
||||
ParamsFor<f32>(3),
|
||||
ParamsFor<i32>(3),
|
||||
ParamsFor<u32>(3),
|
||||
ParamsFor<bool>(4),
|
||||
ParamsFor<f32>(4),
|
||||
ParamsFor<i32>(4),
|
||||
ParamsFor<u32>(4),
|
||||
ParamsFor<alias<bool>>(4),
|
||||
ParamsFor<alias<f32>>(4),
|
||||
ParamsFor<alias<i32>>(4),
|
||||
ParamsFor<alias<u32>>(4)));
|
||||
|
||||
using InvalidVectorElementTypes = ResolverTestWithParam<Params>;
|
||||
TEST_P(InvalidVectorElementTypes, InvalidElementType) {
|
||||
// var a : vecN<EL_TY>;
|
||||
auto& params = GetParam();
|
||||
Global("a", ty.vec(Source{{12, 34}}, params.elem_ty(*this), params.width),
|
||||
ast::StorageClass::kPrivate);
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(
|
||||
r()->error(),
|
||||
"12:34 error: vector element type must be 'bool', 'f32', 'i32' or 'u32'");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
|
||||
InvalidVectorElementTypes,
|
||||
testing::Values(ParamsFor<vec2<f32>>(2),
|
||||
ParamsFor<vec3<i32>>(2),
|
||||
ParamsFor<vec4<u32>>(2),
|
||||
ParamsFor<mat2x2<f32>>(2),
|
||||
ParamsFor<mat3x3<f32>>(2),
|
||||
ParamsFor<mat4x4<f32>>(2),
|
||||
ParamsFor<array<2, f32>>(2)));
|
||||
} // namespace VectorTests
|
||||
|
||||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
|
|
@ -2007,17 +2007,6 @@ std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
|
|||
|
||||
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
|
||||
|
||||
TEST_F(MatrixConstructorTest, Expr_Constructor_Matrix_NotF32) {
|
||||
// m2x2<i32>()
|
||||
SetSource(Source::Location({12, 34}));
|
||||
auto* tc = mat2x2<i32>(
|
||||
create<ast::TypeConstructorExpression>(ty.mat2x2<i32>(), ExprList()));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: matrix element type must be 'f32'");
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
|
||||
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
|
||||
|
||||
|
|
Loading…
Reference in New Issue