Resolver: Enforce matrix constructor type rules

Added enforcement for matrix constructor type rules according to the
table in https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.

Fixed: tint:633
Change-Id: I97fc7f558f04780ed03252d94c071af3e0e07e26
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45020
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Arman Uguray <armansito@chromium.org>
This commit is contained in:
Arman Uguray 2021-03-18 15:43:14 +00:00 committed by Commit Bot service account
parent 512bdf1612
commit 097c75ae3e
3 changed files with 422 additions and 11 deletions

View File

@ -63,6 +63,13 @@ class ScopedAssignment {
T old_value_; T old_value_;
}; };
// Helper function that returns the range union of two source locations. The
// `start` and `end` locations are assumed to refer to the same source file.
Source CombineSourceRange(const Source& start, const Source& end) {
return Source(Source::Range(start.range.begin, end.range.end),
start.file_path, start.file_content);
}
} // namespace } // namespace
Resolver::Resolver(ProgramBuilder* builder) Resolver::Resolver(ProgramBuilder* builder)
@ -572,9 +579,11 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
// obey the constructor type rules laid out in // obey the constructor type rules laid out in
// https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr. // https://gpuweb.github.io/gpuweb/wgsl.html#type-constructor-expr.
if (auto* vec_type = type_ctor->type()->As<type::Vector>()) { if (auto* vec_type = type_ctor->type()->As<type::Vector>()) {
return VectorConstructor(*vec_type, type_ctor->values()); return VectorConstructor(vec_type, type_ctor->values());
}
if (auto* mat_type = type_ctor->type()->As<type::Matrix>()) {
return MatrixConstructor(mat_type, type_ctor->values());
} }
// TODO(crbug.com/tint/633): Validate matrix constructor
// TODO(crbug.com/tint/634): Validate array constructor // TODO(crbug.com/tint/634): Validate array constructor
} else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) { } else if (auto* scalar_ctor = expr->As<ast::ScalarConstructorExpression>()) {
SetType(expr, scalar_ctor->literal()->type()); SetType(expr, scalar_ctor->literal()->type());
@ -584,9 +593,9 @@ bool Resolver::Constructor(ast::ConstructorExpression* expr) {
return true; return true;
} }
bool Resolver::VectorConstructor(const type::Vector& vec_type, bool Resolver::VectorConstructor(const type::Vector* vec_type,
const ast::ExpressionList& values) { const ast::ExpressionList& values) {
type::Type* elem_type = vec_type.type()->UnwrapAll(); type::Type* elem_type = vec_type->type()->UnwrapAll();
size_t value_cardinality_sum = 0; size_t value_cardinality_sum = 0;
for (auto* value : values) { for (auto* value : values) {
type::Type* value_type = TypeOf(value)->UnwrapAll(); type::Type* value_type = TypeOf(value)->UnwrapAll();
@ -635,26 +644,63 @@ bool Resolver::VectorConstructor(const type::Vector& vec_type,
// A correct vector constructor must either be a zero-value expression // A correct vector constructor must either be a zero-value expression
// or the number of components of all constructor arguments must add up // or the number of components of all constructor arguments must add up
// to the vector cardinality. // to the vector cardinality.
if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type.size()) { if (value_cardinality_sum > 0 && value_cardinality_sum != vec_type->size()) {
if (values.empty()) { if (values.empty()) {
TINT_ICE(diagnostics_) TINT_ICE(diagnostics_)
<< "constructor arguments expected to be non-empty!"; << "constructor arguments expected to be non-empty!";
} }
const Source& values_start = values[0]->source(); const Source& values_start = values[0]->source();
const Source& values_end = values[values.size() - 1]->source(); const Source& values_end = values[values.size() - 1]->source();
const Source src(
Source::Range(values_start.range.begin, values_end.range.end),
values_start.file_path, values_start.file_content);
diagnostics_.add_error( diagnostics_.add_error(
"attempted to construct '" + "attempted to construct '" +
vec_type.FriendlyName(builder_->Symbols()) + "' with " + vec_type->FriendlyName(builder_->Symbols()) + "' with " +
std::to_string(value_cardinality_sum) + " component(s)", std::to_string(value_cardinality_sum) + " component(s)",
src); CombineSourceRange(values_start, values_end));
return false; return false;
} }
return true; return true;
} }
bool Resolver::MatrixConstructor(const type::Matrix* matrix_type,
const ast::ExpressionList& values) {
// Zero Value expression
if (values.empty()) {
return true;
}
type::Type* elem_type = matrix_type->type()->UnwrapAll();
if (matrix_type->columns() != values.size()) {
const Source& values_start = values[0]->source();
const Source& values_end = values[values.size() - 1]->source();
diagnostics_.add_error(
"expected " + std::to_string(matrix_type->columns()) + " '" +
VectorPretty(matrix_type->rows(), elem_type) + "' arguments in '" +
matrix_type->FriendlyName(builder_->Symbols()) +
"' constructor, found " + std::to_string(values.size()),
CombineSourceRange(values_start, values_end));
return false;
}
for (auto* value : values) {
type::Type* value_type = TypeOf(value)->UnwrapAll();
auto* value_vec = value_type->As<type::Vector>();
if (!value_vec || value_vec->size() != matrix_type->rows() ||
elem_type != value_vec->type()->UnwrapAll()) {
diagnostics_.add_error(
"expected argument type '" +
VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
matrix_type->FriendlyName(builder_->Symbols()) +
"' constructor, found '" +
value_type->FriendlyName(builder_->Symbols()) + "'",
value->source());
return false;
}
}
return true;
}
bool Resolver::Identifier(ast::IdentifierExpression* expr) { bool Resolver::Identifier(ast::IdentifierExpression* expr) {
auto symbol = expr->symbol(); auto symbol = expr->symbol();
VariableInfo* var; VariableInfo* var;
@ -1501,6 +1547,11 @@ bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) {
return callback(); return callback();
} }
std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
type::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl) Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
: declaration(decl), storage_class(decl->declared_storage_class()) {} : declaration(decl), storage_class(decl->declared_storage_class()) {}

View File

@ -194,7 +194,9 @@ class Resolver {
bool Call(ast::CallExpression*); bool Call(ast::CallExpression*);
bool CaseStatement(ast::CaseStatement*); bool CaseStatement(ast::CaseStatement*);
bool Constructor(ast::ConstructorExpression*); bool Constructor(ast::ConstructorExpression*);
bool VectorConstructor(const type::Vector& vec_type, bool VectorConstructor(const type::Vector* vec_type,
const ast::ExpressionList& values);
bool MatrixConstructor(const type::Matrix* matrix_type,
const ast::ExpressionList& values); const ast::ExpressionList& values);
bool Expression(ast::Expression*); bool Expression(ast::Expression*);
bool Expressions(const ast::ExpressionList&); bool Expressions(const ast::ExpressionList&);
@ -247,6 +249,13 @@ class Resolver {
template <typename F> template <typename F>
bool BlockScope(BlockInfo::Type type, F&& callback); bool BlockScope(BlockInfo::Type type, F&& callback);
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, type::Type* element_type);
ProgramBuilder* const builder_; ProgramBuilder* const builder_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_; std::unique_ptr<IntrinsicTable> const intrinsic_table_;
diag::List diagnostics_; diag::List diagnostics_;

View File

@ -1664,6 +1664,357 @@ TEST_F(ResolverValidationTest,
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
struct MatrixDimensions {
uint32_t rows;
uint32_t columns;
};
std::string MatrixStr(const MatrixDimensions& dimensions,
std::string subtype = "f32") {
return "mat" + std::to_string(dimensions.columns) + "x" +
std::to_string(dimensions.rows) + "<" + subtype + ">";
}
std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
return "vec" + std::to_string(dimensions) + "<" + subtype + ">";
}
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns - 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected " + std::to_string(param.columns) + " '" +
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
"' constructor, found " + std::to_string(param.columns - 1));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns + 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected " + std::to_string(param.columns) + " '" +
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
"' constructor, found " + std::to_string(param.columns + 1));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
// matNxM<f32>(1.0, 1.0, ...); N arguments
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::ScalarConstructorExpression>(Source{{12, i}},
Literal(1.0f)));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found 'f32'");
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_TooFewRowsInVectorArgument) {
// matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
const auto param = GetParam();
// Skip the test if parameters would have resuled in an invalid vec1 type.
if (param.rows == 2) {
return;
}
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows - 1);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns - 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, valid_vec_type, ExprList()));
}
const size_t kInvalidLoc = 2 * (param.columns - 1);
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
" error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows - 1) + "'");
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_TooManyRowsInVectorArgument) {
// matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
const auto param = GetParam();
// Skip the test if parameters would have resuled in an invalid vec5 type.
if (param.rows == 4) {
return;
}
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* valid_vec_type = create<type::Vector>(ty.f32(), param.rows);
auto* invalid_vec_type = create<type::Vector>(ty.f32(), param.rows + 1);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns - 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, valid_vec_type, ExprList()));
}
const size_t kInvalidLoc = 2 * (param.columns - 1);
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, kInvalidLoc}}, invalid_vec_type, ExprList()));
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
" error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows + 1) + "'");
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_ArgumentVectorElementTypeMismatch) {
// matNxM<f32>(vecM<u32>(), ...); with N arguments
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows, "u32") + "'");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
// matNxM<f32>();
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* tc = create<ast::TypeConstructorExpression>(Source{{12, 40}},
matrix_type, ExprList());
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
// matNxM<f32>(vecM<f32>(), ...); with N arguments
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
// matNxM<Float32>(vecM<u32>(), ...); with N arguments
const auto param = GetParam();
auto* f32_alias = ty.alias("Float32", ty.f32());
auto* matrix_type =
create<type::Matrix>(f32_alias, param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.u32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected argument type '" + VecStr(param.rows) +
"' in '" + MatrixStr(param, "Float32") +
"' constructor, found '" + VecStr(param.rows, "u32") + "'");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
// matNxM<Float32>(vecM<f32>(), ...); with N arguments
const auto param = GetParam();
auto* f32_alias = ty.alias("Float32", ty.f32());
auto* matrix_type =
create<type::Matrix>(f32_alias, param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverValidationTest, Expr_MatrixConstructor_ArgumentTypeAlias_Error) {
auto* vec2_alias = ty.alias("VectorUnsigned2", ty.vec2<u32>());
auto* tc = mat2x2<f32>(create<ast::TypeConstructorExpression>(
Source{{12, 34}}, vec2_alias, ExprList()),
vec2<f32>());
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
"constructor, found 'vec2<u32>'");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* vec_type = create<type::Vector>(ty.f32(), param.rows);
auto* vec_alias = ty.alias("VectorFloat2", vec_type);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_alias, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* f32_alias = ty.alias("UnsignedInt", ty.u32());
auto* vec_type = create<type::Vector>(f32_alias, param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows, "UnsignedInt") + "'");
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_ArgumentElementTypeAlias_Success) {
const auto param = GetParam();
auto* matrix_type = create<type::Matrix>(ty.f32(), param.rows, param.columns);
auto* f32_alias = ty.alias("Float32", ty.f32());
auto* vec_type = create<type::Vector>(f32_alias, param.rows);
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
INSTANTIATE_TEST_SUITE_P(ResolverValidationTest,
MatrixConstructorTest,
testing::Values(MatrixDimensions{2, 2},
MatrixDimensions{3, 2},
MatrixDimensions{4, 2},
MatrixDimensions{2, 3},
MatrixDimensions{3, 3},
MatrixDimensions{4, 3},
MatrixDimensions{2, 4},
MatrixDimensions{3, 4},
MatrixDimensions{4, 4}));
} // namespace } // namespace
} // namespace resolver } // namespace resolver
} // namespace tint } // namespace tint