tint: Add matrix identify and single-scalar ctors

Fixed: tint:1545
Change-Id: I86451223765f620861bf98861142e6d34c7e945b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90502
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-05-16 21:14:11 +00:00
committed by Dawn LUCI CQ
parent 31b379409d
commit 3b5edf1435
95 changed files with 5468 additions and 4232 deletions

View File

@@ -14,12 +14,14 @@
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
#include <unordered_map>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
@@ -44,6 +46,8 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
}
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
std::unordered_map<const sem::Matrix*, Symbol> scalar_ctors;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = ctx.src->Sem().Get(expr);
auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
@@ -64,21 +68,56 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
return nullptr;
}
// Build a list of vector expressions for each column.
ast::ExpressionList columns;
for (uint32_t c = 0; c < mat_type->columns(); c++) {
// Build a list of scalar expressions for each value in the column.
ast::ExpressionList row_values;
for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values.push_back(ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
}
// Constructs a matrix using vector columns, with the elements constructed using the
// 'element(uint32_t c, uint32_t r)' callback.
auto build_mat = [&](auto&& element) {
ast::ExpressionList columns(mat_type->columns());
for (uint32_t c = 0; c < mat_type->columns(); c++) {
ast::ExpressionList row_values(mat_type->rows());
for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values[r] = element(c, r);
}
// Construct the column vector.
auto* col =
ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), row_values);
columns.push_back(col);
// Construct the column vector.
columns[c] = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
row_values);
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
};
if (args.size() == 1) {
// Generate a helper function for constructing the matrix.
// This is done to ensure that the single argument value is only evaluated once, and
// with the correct expression evaluation order.
auto fn = utils::GetOrCreate(scalar_ctors, mat_type, [&] {
auto name =
ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
std::to_string(mat_type->rows()));
ctx.dst->Func(name,
{
// Single scalar parameter
ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
},
CreateASTTypeFor(ctx, mat_type),
{
ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
return ctx.dst->Expr("value");
})),
});
return name;
});
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
if (args.size() == mat_type->columns() * mat_type->rows()) {
return build_mat([&](uint32_t c, uint32_t r) {
return ctx.Clone(args[c * mat_type->rows() + r]->Declaration());
});
}
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "matrix constructor has unexpected number of arguments";
return nullptr;
});
ctx.Clone();

View File

@@ -31,7 +31,57 @@ TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalars) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
std::string matrix = matrix_no_type + "<f32>";
std::string vector = "vec" + std::to_string(rows) + "<f32>";
std::string values;
for (uint32_t c = 0; c < cols; c++) {
if (c > 0) {
values += ", ";
}
values += vector + "(";
for (uint32_t r = 0; r < rows; r++) {
if (r > 0) {
values += ", ";
}
values += "value";
}
values += ")";
}
std::string src = R"(
@stage(fragment)
fn main() {
let m = ${matrix}(42.0);
}
)";
std::string expect = R"(
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
return ${matrix}(${values});
}
@stage(fragment)
fn main() {
let m = build_${matrix_no_type}(42.0);
}
)";
src = utils::ReplaceAll(src, "${matrix}", matrix);
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
expect = utils::ReplaceAll(expect, "${values}", values);
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";