mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-15 16:16:08 +00:00
Add support for matrix construction from scalars
Use a transform to convert these to the vector form for the MSL and SPIR-V backends. MSL only has the scalar form from version 2.0 onwards. Fixed: tint:1123 Change-Id: I384abd9872d9eae52a10a37cbd6aa96004692e9c Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67360 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
c73b57f2c6
commit
91689fb6f0
@@ -2885,28 +2885,47 @@ bool Resolver::ValidateMatrixConstructor(
|
||||
}
|
||||
|
||||
auto* elem_type = matrix_type->type();
|
||||
if (matrix_type->columns() != values.size()) {
|
||||
auto num_elements = matrix_type->columns() * matrix_type->rows();
|
||||
|
||||
// Print a generic error for an invalid matrix constructor, showing the
|
||||
// available overloads.
|
||||
auto print_error = [&]() {
|
||||
const Source& values_start = values[0]->source;
|
||||
const Source& values_end = values[values.size() - 1]->source;
|
||||
AddError("expected " + std::to_string(matrix_type->columns()) + " '" +
|
||||
VectorPretty(matrix_type->rows(), elem_type) +
|
||||
"' arguments in '" + type_name + "' constructor, found " +
|
||||
std::to_string(values.size()),
|
||||
Source::Combine(values_start, values_end));
|
||||
auto elem_type_name = elem_type->FriendlyName(builder_->Symbols());
|
||||
std::stringstream ss;
|
||||
ss << "invalid constructor for " + type_name << std::endl << std::endl;
|
||||
ss << "3 candidates available:" << std::endl;
|
||||
ss << " " << type_name << "()" << std::endl;
|
||||
ss << " " << type_name << "(" << elem_type_name << ",...,"
|
||||
<< elem_type_name << ")"
|
||||
<< " // " << std::to_string(num_elements) << " arguments" << std::endl;
|
||||
ss << " " << type_name << "(";
|
||||
for (uint32_t c = 0; c < matrix_type->columns(); c++) {
|
||||
if (c > 0) {
|
||||
ss << ", ";
|
||||
}
|
||||
ss << VectorPretty(matrix_type->rows(), elem_type);
|
||||
}
|
||||
ss << ")" << std::endl;
|
||||
AddError(ss.str(), Source::Combine(values_start, values_end));
|
||||
};
|
||||
|
||||
const sem::Type* expected_arg_type = nullptr;
|
||||
if (num_elements == values.size()) {
|
||||
// Column-major construction from scalar elements.
|
||||
expected_arg_type = matrix_type->type();
|
||||
} else if (matrix_type->columns() == values.size()) {
|
||||
// Column-by-column construction from vectors.
|
||||
expected_arg_type = matrix_type->ColumnType();
|
||||
} else {
|
||||
print_error();
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* value : values) {
|
||||
auto* value_type = TypeOf(value)->UnwrapRef();
|
||||
auto* value_vec = value_type->As<sem::Vector>();
|
||||
|
||||
if (!value_vec || value_vec->Width() != matrix_type->rows() ||
|
||||
elem_type != value_vec->type()) {
|
||||
AddError("expected argument type '" +
|
||||
VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
|
||||
type_name + "' constructor, found '" + TypeNameOf(value) +
|
||||
"'",
|
||||
value->source);
|
||||
if (TypeOf(value)->UnwrapRef() != expected_arg_type) {
|
||||
print_error();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "src/resolver/resolver_test_helper.h"
|
||||
#include "src/sem/reference_type.h"
|
||||
|
||||
@@ -19,6 +20,8 @@ namespace tint {
|
||||
namespace resolver {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
// Helpers and typedefs
|
||||
using builder::alias;
|
||||
using builder::alias1;
|
||||
@@ -1641,13 +1644,9 @@ static std::string MatrixStr(const MatrixDimensions& dimensions,
|
||||
std::to_string(dimensions.rows) + "<" + subtype + ">";
|
||||
}
|
||||
|
||||
static std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
|
||||
return "vec" + std::to_string(dimensions) + "<" + subtype + ">";
|
||||
}
|
||||
|
||||
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewArguments) {
|
||||
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
@@ -1665,13 +1664,34 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:1 error: expected " + std::to_string(param.columns) + " '" +
|
||||
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
|
||||
"' constructor, found " + std::to_string(param.columns - 1));
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
|
||||
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooFewArguments) {
|
||||
// matNxM<f32>(f32,...,f32); with N*M - 1 arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns * param.rows - 1; i++) {
|
||||
args.push_back(create<ast::TypeConstructorExpression>(
|
||||
Source{{12, i}}, ty.f32(), ExprList()));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
|
||||
std::move(args));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyArguments) {
|
||||
// matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
@@ -1689,21 +1709,20 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:1 error: expected " + std::to_string(param.columns) + " '" +
|
||||
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
|
||||
"' constructor, found " + std::to_string(param.columns + 1));
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
|
||||
// matNxM<f32>(1.0, 1.0, ...); N arguments
|
||||
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooManyArguments) {
|
||||
// matNxM<f32>(f32,...,f32); with N*M + 1 arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns; i++) {
|
||||
args.push_back(create<ast::ScalarConstructorExpression>(Source{{12, i}},
|
||||
Literal(1.0f)));
|
||||
for (uint32_t i = 1; i <= param.columns * param.rows + 1; i++) {
|
||||
args.push_back(create<ast::TypeConstructorExpression>(
|
||||
Source{{12, i}}, ty.f32(), ExprList()));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
@@ -1712,13 +1731,60 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
|
||||
VecStr(param.rows) + "' in '" + MatrixStr(param) +
|
||||
"' constructor, found 'f32'");
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
Expr_Constructor_Error_TooFewRowsInVectorArgument) {
|
||||
Expr_ColumnConstructor_Error_InvalidArgumentType) {
|
||||
// matNxM<f32>(vec<u32>, vec<u32>, ...); N arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns; i++) {
|
||||
auto* vec_type = ty.vec<u32>(param.rows);
|
||||
args.push_back(create<ast::TypeConstructorExpression>(
|
||||
Source{{12, i}}, vec_type, ExprList()));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
|
||||
std::move(args));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
Expr_ElementConstructor_Error_InvalidArgumentType) {
|
||||
// matNxM<f32>(u32, u32, ...); N*M arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns; i++) {
|
||||
args.push_back(
|
||||
create<ast::ScalarConstructorExpression>(Source{{12, i}}, Literal(1u)));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
|
||||
std::move(args));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
Expr_ColumnConstructor_Error_TooFewRowsInVectorArgument) {
|
||||
// matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
|
||||
|
||||
const auto param = GetParam();
|
||||
@@ -1745,15 +1811,13 @@ TEST_P(MatrixConstructorTest,
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
|
||||
" error: expected argument type '" +
|
||||
VecStr(param.rows) + "' in '" + MatrixStr(param) +
|
||||
"' constructor, found '" +
|
||||
VecStr(param.rows - 1) + "'");
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
Expr_Constructor_Error_TooManyRowsInVectorArgument) {
|
||||
Expr_ColumnConstructor_Error_TooManyRowsInVectorArgument) {
|
||||
// matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
|
||||
|
||||
const auto param = GetParam();
|
||||
@@ -1780,36 +1844,9 @@ TEST_P(MatrixConstructorTest,
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
|
||||
" error: expected argument type '" +
|
||||
VecStr(param.rows) + "' in '" + MatrixStr(param) +
|
||||
"' constructor, found '" +
|
||||
VecStr(param.rows + 1) + "'");
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
Expr_Constructor_Error_ArgumentVectorElementTypeMismatch) {
|
||||
// matNxM<f32>(vecM<u32>(), ...); with N arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns; i++) {
|
||||
auto* vec_type = ty.vec<u32>(param.rows);
|
||||
args.push_back(create<ast::TypeConstructorExpression>(
|
||||
Source{{12, i}}, vec_type, ExprList()));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
|
||||
std::move(args));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
|
||||
VecStr(param.rows) + "' in '" + MatrixStr(param) +
|
||||
"' constructor, found '" +
|
||||
VecStr(param.rows, "u32") + "'");
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
|
||||
@@ -1824,7 +1861,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
|
||||
// matNxM<f32>(vecM<f32>(), ...); with N arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
@@ -1844,6 +1881,25 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
|
||||
// matNxM<f32>(f32,...,f32); with N*M arguments
|
||||
|
||||
const auto param = GetParam();
|
||||
|
||||
ast::ExpressionList args;
|
||||
for (uint32_t i = 1; i <= param.columns * param.rows; i++) {
|
||||
args.push_back(create<ast::TypeConstructorExpression>(
|
||||
Source{{12, i}}, ty.f32(), ExprList()));
|
||||
}
|
||||
|
||||
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
|
||||
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
|
||||
std::move(args));
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
|
||||
// matNxM<Float32>(vecM<u32>(), ...); with N arguments
|
||||
|
||||
@@ -1863,10 +1919,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:1 error: expected argument type '" + VecStr(param.rows) +
|
||||
"' in '" + MatrixStr(param, "Float32") +
|
||||
"' constructor, found '" + VecStr(param.rows, "u32") + "'");
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param, "Float32") +
|
||||
"\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
|
||||
@@ -1900,8 +1955,13 @@ TEST_F(ResolverTypeConstructorValidationTest,
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
|
||||
"constructor, found 'VectorUnsigned2'");
|
||||
R"(12:34 error: invalid constructor for mat2x2<f32>
|
||||
|
||||
3 candidates available:
|
||||
mat2x2<f32>()
|
||||
mat2x2<f32>(f32,...,f32) // 4 arguments
|
||||
mat2x2<f32>(vec2<f32>, vec2<f32>)
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
|
||||
@@ -1940,10 +2000,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
|
||||
VecStr(param.rows) + "' in '" + MatrixStr(param) +
|
||||
"' constructor, found '" +
|
||||
VecStr(param.rows, "UnsignedInt") + "'");
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:1 error: invalid constructor for " +
|
||||
MatrixStr(param) + "\n\n3 candidates available:"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest,
|
||||
|
||||
Reference in New Issue
Block a user