mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-18 17:35:30 +00:00
tint: Add matrix identify and single-scalar ctors
Fixed: tint:1545 Change-Id: I86451223765f620861bf98861142e6d34c7e945b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90502 Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
31b379409d
commit
3b5edf1435
@@ -605,6 +605,15 @@ ctor bool(bool) -> bool
|
||||
ctor vec2<T: scalar>(vec2<T>) -> vec2<T>
|
||||
ctor vec3<T: scalar>(vec3<T>) -> vec3<T>
|
||||
ctor vec4<T: scalar>(vec4<T>) -> vec4<T>
|
||||
ctor mat2x2<f32>(mat2x2<f32>) -> mat2x2<f32>
|
||||
ctor mat2x3<f32>(mat2x3<f32>) -> mat2x3<f32>
|
||||
ctor mat2x4<f32>(mat2x4<f32>) -> mat2x4<f32>
|
||||
ctor mat3x2<f32>(mat3x2<f32>) -> mat3x2<f32>
|
||||
ctor mat3x3<f32>(mat3x3<f32>) -> mat3x3<f32>
|
||||
ctor mat3x4<f32>(mat3x4<f32>) -> mat3x4<f32>
|
||||
ctor mat4x2<f32>(mat4x2<f32>) -> mat4x2<f32>
|
||||
ctor mat4x3<f32>(mat4x3<f32>) -> mat4x3<f32>
|
||||
ctor mat4x4<f32>(mat4x4<f32>) -> mat4x4<f32>
|
||||
|
||||
// Vector constructors
|
||||
ctor vec2<T: scalar>(T) -> vec2<T>
|
||||
@@ -623,50 +632,59 @@ ctor vec4<T: scalar>(xyz: vec3<T>, w: T) -> vec4<T>
|
||||
ctor vec4<T: scalar>(x: T, zyw: vec3<T>) -> vec4<T>
|
||||
|
||||
// Matrix constructors
|
||||
ctor mat2x2(f32, f32,
|
||||
f32, f32) -> mat2x2<f32>
|
||||
ctor mat2x2(vec2<f32>, vec2<f32>) -> mat2x2<f32>
|
||||
ctor mat2x2<T: f32>(T) -> mat2x2<T>
|
||||
ctor mat2x2<T: f32>(T, T,
|
||||
T, T) -> mat2x2<T>
|
||||
ctor mat2x2<T: f32>(vec2<T>, vec2<T>) -> mat2x2<T>
|
||||
|
||||
ctor mat2x3(f32, f32, f32,
|
||||
f32, f32, f32) -> mat2x3<f32>
|
||||
ctor mat2x3(vec3<f32>, vec3<f32>) -> mat2x3<f32>
|
||||
ctor mat2x3<T: f32>(T) -> mat2x3<T>
|
||||
ctor mat2x3<T: f32>(T, T, T,
|
||||
T, T, T) -> mat2x3<T>
|
||||
ctor mat2x3<T: f32>(vec3<T>, vec3<T>) -> mat2x3<T>
|
||||
|
||||
ctor mat2x4(f32, f32, f32, f32,
|
||||
f32, f32, f32, f32) -> mat2x4<f32>
|
||||
ctor mat2x4(vec4<f32>, vec4<f32>) -> mat2x4<f32>
|
||||
ctor mat2x4<T: f32>(T) -> mat2x4<T>
|
||||
ctor mat2x4<T: f32>(T, T, T, T,
|
||||
T, T, T, T) -> mat2x4<T>
|
||||
ctor mat2x4<T: f32>(vec4<T>, vec4<T>) -> mat2x4<T>
|
||||
|
||||
ctor mat3x2(f32, f32,
|
||||
f32, f32,
|
||||
f32, f32) -> mat3x2<f32>
|
||||
ctor mat3x2(vec2<f32>, vec2<f32>, vec2<f32>) -> mat3x2<f32>
|
||||
ctor mat3x2<T: f32>(T) -> mat3x2<T>
|
||||
ctor mat3x2<T: f32>(T, T,
|
||||
T, T,
|
||||
T, T) -> mat3x2<T>
|
||||
ctor mat3x2<T: f32>(vec2<T>, vec2<T>, vec2<T>) -> mat3x2<T>
|
||||
|
||||
ctor mat3x3(f32, f32, f32,
|
||||
f32, f32, f32,
|
||||
f32, f32, f32) -> mat3x3<f32>
|
||||
ctor mat3x3(vec3<f32>, vec3<f32>, vec3<f32>) -> mat3x3<f32>
|
||||
ctor mat3x3<T: f32>(T) -> mat3x3<T>
|
||||
ctor mat3x3<T: f32>(T, T, T,
|
||||
T, T, T,
|
||||
T, T, T) -> mat3x3<T>
|
||||
ctor mat3x3<T: f32>(vec3<T>, vec3<T>, vec3<T>) -> mat3x3<T>
|
||||
|
||||
ctor mat3x4(f32, f32, f32, f32,
|
||||
f32, f32, f32, f32,
|
||||
f32, f32, f32, f32) -> mat3x4<f32>
|
||||
ctor mat3x4(vec4<f32>, vec4<f32>, vec4<f32>) -> mat3x4<f32>
|
||||
ctor mat3x4<T: f32>(T) -> mat3x4<T>
|
||||
ctor mat3x4<T: f32>(T, T, T, T,
|
||||
T, T, T, T,
|
||||
T, T, T, T) -> mat3x4<T>
|
||||
ctor mat3x4<T: f32>(vec4<T>, vec4<T>, vec4<T>) -> mat3x4<T>
|
||||
|
||||
ctor mat4x2(f32, f32,
|
||||
f32, f32,
|
||||
f32, f32,
|
||||
f32, f32) -> mat4x2<f32>
|
||||
ctor mat4x2(vec2<f32>, vec2<f32>, vec2<f32>, vec2<f32>) -> mat4x2<f32>
|
||||
ctor mat4x2<T: f32>(T) -> mat4x2<T>
|
||||
ctor mat4x2<T: f32>(T, T,
|
||||
T, T,
|
||||
T, T,
|
||||
T, T) -> mat4x2<T>
|
||||
ctor mat4x2<T: f32>(vec2<T>, vec2<T>, vec2<T>, vec2<T>) -> mat4x2<T>
|
||||
|
||||
ctor mat4x3(f32, f32, f32,
|
||||
f32, f32, f32,
|
||||
f32, f32, f32,
|
||||
f32, f32, f32) -> mat4x3<f32>
|
||||
ctor mat4x3(vec3<f32>, vec3<f32>, vec3<f32>, vec3<f32>) -> mat4x3<f32>
|
||||
ctor mat4x3<T: f32>(T) -> mat4x3<T>
|
||||
ctor mat4x3<T: f32>(T, T, T,
|
||||
T, T, T,
|
||||
T, T, T,
|
||||
T, T, T) -> mat4x3<T>
|
||||
ctor mat4x3<T: f32>(vec3<T>, vec3<T>, vec3<T>, vec3<T>) -> mat4x3<T>
|
||||
|
||||
ctor mat4x4(f32, f32, f32, f32,
|
||||
f32, f32, f32, f32,
|
||||
f32, f32, f32, f32,
|
||||
f32, f32, f32, f32) -> mat4x4<f32>
|
||||
ctor mat4x4(vec4<f32>, vec4<f32>, vec4<f32>, vec4<f32>) -> mat4x4<f32>
|
||||
ctor mat4x4<T: f32>(T) -> mat4x4<T>
|
||||
ctor mat4x4<T: f32>(T, T, T, T,
|
||||
T, T, T, T,
|
||||
T, T, T, T,
|
||||
T, T, T, T) -> mat4x4<T>
|
||||
ctor mat4x4<T: f32>(vec4<T>, vec4<T>, vec4<T>, vec4<T>) -> mat4x4<T>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Type conversions //
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -235,15 +235,18 @@ constexpr Params ParamsFor(Kind kind) {
|
||||
}
|
||||
|
||||
static constexpr Params valid_cases[] = {
|
||||
// Direct init (non-conversions)
|
||||
ParamsFor<bool, bool>(Kind::Construct), //
|
||||
ParamsFor<i32, i32>(Kind::Construct), //
|
||||
ParamsFor<u32, u32>(Kind::Construct), //
|
||||
ParamsFor<f32, f32>(Kind::Construct), //
|
||||
ParamsFor<vec3<bool>, vec3<bool>>(Kind::Construct), //
|
||||
ParamsFor<vec3<i32>, vec3<i32>>(Kind::Construct), //
|
||||
ParamsFor<vec3<u32>, vec3<u32>>(Kind::Construct), //
|
||||
ParamsFor<vec3<f32>, vec3<f32>>(Kind::Construct), //
|
||||
// Identity
|
||||
ParamsFor<bool, bool>(Kind::Construct), //
|
||||
ParamsFor<i32, i32>(Kind::Construct), //
|
||||
ParamsFor<u32, u32>(Kind::Construct), //
|
||||
ParamsFor<f32, f32>(Kind::Construct), //
|
||||
ParamsFor<vec3<bool>, vec3<bool>>(Kind::Construct), //
|
||||
ParamsFor<vec3<i32>, vec3<i32>>(Kind::Construct), //
|
||||
ParamsFor<vec3<u32>, vec3<u32>>(Kind::Construct), //
|
||||
ParamsFor<vec3<f32>, vec3<f32>>(Kind::Construct), //
|
||||
ParamsFor<mat3x3<f32>, mat3x3<f32>>(Kind::Construct), //
|
||||
ParamsFor<mat2x3<f32>, mat2x3<f32>>(Kind::Construct), //
|
||||
ParamsFor<mat3x2<f32>, mat3x2<f32>>(Kind::Construct), //
|
||||
|
||||
// Splat
|
||||
ParamsFor<vec3<bool>, bool>(Kind::Construct), //
|
||||
@@ -251,6 +254,10 @@ static constexpr Params valid_cases[] = {
|
||||
ParamsFor<vec3<u32>, u32>(Kind::Construct), //
|
||||
ParamsFor<vec3<f32>, f32>(Kind::Construct), //
|
||||
|
||||
ParamsFor<mat3x3<f32>, f32>(Kind::Construct), //
|
||||
ParamsFor<mat2x3<f32>, f32>(Kind::Construct), //
|
||||
ParamsFor<mat3x2<f32>, f32>(Kind::Construct), //
|
||||
|
||||
// Conversion
|
||||
ParamsFor<bool, u32>(Kind::Conversion), //
|
||||
ParamsFor<bool, i32>(Kind::Conversion), //
|
||||
@@ -2016,9 +2023,8 @@ TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooFewArguments) {
|
||||
@@ -2041,9 +2047,8 @@ TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooFewArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyArguments) {
|
||||
@@ -2067,9 +2072,8 @@ TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooManyArguments) {
|
||||
@@ -2092,9 +2096,8 @@ TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooManyArguments) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_InvalidArgumentType) {
|
||||
@@ -2118,9 +2121,8 @@ TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_InvalidArgumentType)
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_InvalidArgumentType) {
|
||||
@@ -2143,9 +2145,8 @@ TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_InvalidArgumentType)
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewRowsInVectorArgument) {
|
||||
@@ -2178,9 +2179,8 @@ TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewRowsInVectorArg
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyRowsInVectorArgument) {
|
||||
@@ -2212,9 +2212,8 @@ TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyRowsInVectorAr
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
|
||||
@@ -2285,9 +2284,8 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
|
||||
@@ -2357,9 +2355,8 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
|
||||
WrapInFunction(tc);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(),
|
||||
HasSubstr("12:34 error: no matching constructor for " + MatrixStr(param) + "(" +
|
||||
args_tys.str() + ")\n\n3 candidate constructors:"));
|
||||
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching constructor for " +
|
||||
MatrixStr(param) + "(" + args_tys.str() + ")"));
|
||||
}
|
||||
|
||||
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Success) {
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
|
||||
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/type_constructor.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
|
||||
|
||||
@@ -44,6 +46,8 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
|
||||
}
|
||||
|
||||
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
std::unordered_map<const sem::Matrix*, Symbol> scalar_ctors;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* call = ctx.src->Sem().Get(expr);
|
||||
auto* ty_ctor = call->Target()->As<sem::TypeConstructor>();
|
||||
@@ -64,21 +68,56 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Build a list of vector expressions for each column.
|
||||
ast::ExpressionList columns;
|
||||
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
||||
// Build a list of scalar expressions for each value in the column.
|
||||
ast::ExpressionList row_values;
|
||||
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
||||
row_values.push_back(ctx.Clone(args[c * mat_type->rows() + r]->Declaration()));
|
||||
}
|
||||
// Constructs a matrix using vector columns, with the elements constructed using the
|
||||
// 'element(uint32_t c, uint32_t r)' callback.
|
||||
auto build_mat = [&](auto&& element) {
|
||||
ast::ExpressionList columns(mat_type->columns());
|
||||
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
||||
ast::ExpressionList row_values(mat_type->rows());
|
||||
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
||||
row_values[r] = element(c, r);
|
||||
}
|
||||
|
||||
// Construct the column vector.
|
||||
auto* col =
|
||||
ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(), row_values);
|
||||
columns.push_back(col);
|
||||
// Construct the column vector.
|
||||
columns[c] = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
|
||||
row_values);
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
};
|
||||
|
||||
if (args.size() == 1) {
|
||||
// Generate a helper function for constructing the matrix.
|
||||
// This is done to ensure that the single argument value is only evaluated once, and
|
||||
// with the correct expression evaluation order.
|
||||
auto fn = utils::GetOrCreate(scalar_ctors, mat_type, [&] {
|
||||
auto name =
|
||||
ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
|
||||
std::to_string(mat_type->rows()));
|
||||
ctx.dst->Func(name,
|
||||
{
|
||||
// Single scalar parameter
|
||||
ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
|
||||
},
|
||||
CreateASTTypeFor(ctx, mat_type),
|
||||
{
|
||||
ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
|
||||
return ctx.dst->Expr("value");
|
||||
})),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
|
||||
if (args.size() == mat_type->columns() * mat_type->rows()) {
|
||||
return build_mat([&](uint32_t c, uint32_t r) {
|
||||
return ctx.Clone(args[c * mat_type->rows() + r]->Declaration());
|
||||
});
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "matrix constructor has unexpected number of arguments";
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
@@ -31,7 +31,57 @@ TEST_F(VectorizeScalarMatrixConstructorsTest, ShouldRunEmptyModule) {
|
||||
EXPECT_FALSE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalars) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
|
||||
std::string matrix = matrix_no_type + "<f32>";
|
||||
std::string vector = "vec" + std::to_string(rows) + "<f32>";
|
||||
std::string values;
|
||||
for (uint32_t c = 0; c < cols; c++) {
|
||||
if (c > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += vector + "(";
|
||||
for (uint32_t r = 0; r < rows; r++) {
|
||||
if (r > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += "value";
|
||||
}
|
||||
values += ")";
|
||||
}
|
||||
|
||||
std::string src = R"(
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let m = ${matrix}(42.0);
|
||||
}
|
||||
)";
|
||||
|
||||
std::string expect = R"(
|
||||
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
|
||||
return ${matrix}(${values});
|
||||
}
|
||||
|
||||
@stage(fragment)
|
||||
fn main() {
|
||||
let m = build_${matrix_no_type}(42.0);
|
||||
}
|
||||
)";
|
||||
src = utils::ReplaceAll(src, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
|
||||
expect = utils::ReplaceAll(expect, "${values}", values);
|
||||
|
||||
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
|
||||
|
||||
@@ -66,6 +66,7 @@
|
||||
#include "src/tint/transform/simplify_pointers.h"
|
||||
#include "src/tint/transform/unshadow.h"
|
||||
#include "src/tint/transform/unwind_discard_functions.h"
|
||||
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
|
||||
#include "src/tint/transform/zero_init_workgroup_memory.h"
|
||||
#include "src/tint/utils/defer.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
@@ -198,6 +199,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
|
||||
manager.Add<transform::ExpandCompoundAssignment>();
|
||||
manager.Add<transform::PromoteSideEffectsToDecl>();
|
||||
manager.Add<transform::UnwindDiscardFunctions>();
|
||||
manager.Add<transform::VectorizeScalarMatrixConstructors>();
|
||||
manager.Add<transform::SimplifyPointers>();
|
||||
manager.Add<transform::RemovePhonies>();
|
||||
// ArrayLengthFromUniform must come after InlinePointerLets and Simplify, as
|
||||
@@ -1080,6 +1082,55 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& out,
|
||||
return EmitZeroValue(out, type);
|
||||
}
|
||||
|
||||
if (auto* mat = call->Type()->As<sem::Matrix>()) {
|
||||
if (ctor->Parameters().size() == 1) {
|
||||
// Matrix constructor with single scalar.
|
||||
auto fn = utils::GetOrCreate(matrix_scalar_ctors_, mat, [&]() -> std::string {
|
||||
TextBuffer b;
|
||||
TINT_DEFER(helpers_.Append(b));
|
||||
|
||||
auto name = UniqueIdentifier("build_mat" + std::to_string(mat->columns()) + "x" +
|
||||
std::to_string(mat->rows()));
|
||||
{
|
||||
auto l = line(&b);
|
||||
if (!EmitType(l, mat, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
|
||||
return "";
|
||||
}
|
||||
l << " " << name << "(";
|
||||
if (!EmitType(l, mat->type(), ast::StorageClass::kNone, ast::Access::kUndefined,
|
||||
"")) {
|
||||
return "";
|
||||
}
|
||||
l << " value) {";
|
||||
}
|
||||
{
|
||||
ScopedIndent si(&b);
|
||||
auto l = line(&b);
|
||||
l << "return ";
|
||||
if (!EmitType(l, mat, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
|
||||
return "";
|
||||
}
|
||||
l << "(";
|
||||
for (uint32_t i = 0; i < mat->columns() * mat->rows(); i++) {
|
||||
l << ((i > 0) ? ", value" : "value");
|
||||
}
|
||||
l << ");";
|
||||
}
|
||||
line(&b) << "}";
|
||||
return name;
|
||||
});
|
||||
if (fn.empty()) {
|
||||
return false;
|
||||
}
|
||||
out << fn << "(";
|
||||
if (!EmitExpression(out, call->Arguments()[0]->Declaration())) {
|
||||
return false;
|
||||
}
|
||||
out << ")";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool brackets = type->IsAnyOf<sem::Array, sem::Struct>();
|
||||
|
||||
// For single-value vector initializers, swizzle the scalar to the right
|
||||
|
||||
@@ -512,6 +512,7 @@ class GeneratorImpl : public TextGenerator {
|
||||
TextBuffer helpers_; // Helper functions emitted at the top of the output
|
||||
std::function<bool()> emit_continuing_;
|
||||
std::unordered_map<DMAIntrinsic, std::string, DMAIntrinsic::Hasher> dma_intrinsics_;
|
||||
std::unordered_map<const sem::Matrix*, std::string> matrix_scalar_ctors_;
|
||||
std::unordered_map<const sem::Builtin*, std::string> builtins_;
|
||||
std::unordered_map<const sem::Struct*, std::string> structure_builders_;
|
||||
std::unordered_map<const sem::Vector*, std::string> dynamic_vector_write_;
|
||||
|
||||
Reference in New Issue
Block a user