mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-20 10:25:28 +00:00
tint: Add matrix identify and single-scalar ctors
Fixed: tint:1545 Change-Id: I86451223765f620861bf98861142e6d34c7e945b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90502 Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
31b379409d
commit
3b5edf1435
@@ -14,12 +14,14 @@
|
||||
|
||||
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
|
||||
|
||||
@@ -44,6 +46,8 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
|
||||
}
|
||||
|
||||
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
std::unordered_map<const sem::Matrix*, Symbol> scalar_ctors;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* call = ctx.src->Sem().Get(expr);
|
||||
auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
|
||||
@@ -64,21 +68,56 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Build a list of vector expressions for each column.
|
||||
ast::ExpressionList columns;
|
||||
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
||||
// Build a list of scalar expressions for each value in the column.
|
||||
ast::ExpressionList row_values;
|
||||
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
||||
row_values.push_back(ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
|
||||
}
|
||||
// Constructs a matrix using vector columns, with the elements constructed using the
|
||||
// 'element(uint32_t c, uint32_t r)' callback.
|
||||
auto build_mat = [&](auto&& element) {
|
||||
ast::ExpressionList columns(mat_type->columns());
|
||||
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
||||
ast::ExpressionList row_values(mat_type->rows());
|
||||
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
||||
row_values[r] = element(c, r);
|
||||
}
|
||||
|
||||
// Construct the column vector.
|
||||
auto* col =
|
||||
ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), row_values);
|
||||
columns.push_back(col);
|
||||
// Construct the column vector.
|
||||
columns[c] = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
|
||||
row_values);
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
};
|
||||
|
||||
if (args.size() == 1) {
|
||||
// Generate a helper function for constructing the matrix.
|
||||
// This is done to ensure that the single argument value is only evaluated once, and
|
||||
// with the correct expression evaluation order.
|
||||
auto fn = utils::GetOrCreate(scalar_ctors, mat_type, [&] {
|
||||
auto name =
|
||||
ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
|
||||
std::to_string(mat_type->rows()));
|
||||
ctx.dst->Func(name,
|
||||
{
|
||||
// Single scalar parameter
|
||||
ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
|
||||
},
|
||||
CreateASTTypeFor(ctx, mat_type),
|
||||
{
|
||||
ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
|
||||
return ctx.dst->Expr("value");
|
||||
})),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
|
||||
if (args.size() == mat_type->columns() * mat_type->rows()) {
|
||||
return build_mat([&](uint32_t c, uint32_t r) {
|
||||
return ctx.Clone(args[c * mat_type->rows() + r]->Declaration());
|
||||
});
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "matrix constructor has unexpected number of arguments";
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
@@ -31,7 +31,57 @@ TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
|
||||
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalars) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
|
||||
std::string matrix = matrix_no_type + "<f32>";
|
||||
std::string vector = "vec" + std::to_string(rows) + "<f32>";
|
||||
std::string values;
|
||||
for (uint32_t c = 0; c < cols; c++) {
|
||||
if (c > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += vector + "(";
|
||||
for (uint32_t r = 0; r < rows; r++) {
|
||||
if (r > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += "value";
|
||||
}
|
||||
values += ")";
|
||||
}
|
||||
|
||||
std::string src = R"(
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let m = ${matrix}(42.0);
|
||||
}
|
||||
)";
|
||||
|
||||
std::string expect = R"(
|
||||
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
|
||||
return ${matrix}(${values});
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let m = build_${matrix_no_type}(42.0);
|
||||
}
|
||||
)";
|
||||
src = utils::ReplaceAll(src, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
|
||||
expect = utils::ReplaceAll(expect, "${values}", values);
|
||||
|
||||
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
|
||||
|
||||
Reference in New Issue
Block a user