From d713669329430ccebbf2c0eb1d3667388adddfcb Mon Sep 17 00:00:00 2001 From: Zhaoming Jiang Date: Fri, 17 Jun 2022 15:36:59 +0000 Subject: [PATCH] Tint: Fix VectorizeScalarMatrixConstructors to run for ref to scalar This patch make VectorizeScalarMatrixConstructors transform run for reference to scalar as well as scalar types node, i.e. run for `mat2x2(v[2])`, where `v` is a f32 vector, as well as `mat2x2(1.0)`. Bug: tint:1589 Change-Id: I5d3e367ee6a9826b3e1add3495aaac0ae326be14 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94023 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Zhaoming Jiang Reviewed-by: Ryan Harrison --- .../vectorize_scalar_matrix_constructors.cc | 4 +- ...ctorize_scalar_matrix_constructors_test.cc | 95 +++++++++++++++++++ 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors.cc b/src/tint/transform/vectorize_scalar_matrix_constructors.cc index 9cd9757bbb..e7c76e3ca6 100644 --- a/src/tint/transform/vectorize_scalar_matrix_constructors.cc +++ b/src/tint/transform/vectorize_scalar_matrix_constructors.cc @@ -36,7 +36,7 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const if (auto* call = program->Sem().Get(node)) { if (call->Target()->Is() && call->Type()->Is()) { auto& args = call->Arguments(); - if (args.size() > 0 && args[0]->Type()->is_scalar()) { + if (args.size() > 0 && args[0]->Type()->UnwrapRef()->is_scalar()) { return true; } } @@ -64,7 +64,7 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D if (args.size() == 0) { return nullptr; } - if (!args[0]->Type()->is_scalar()) { + if (!args[0]->Type()->UnwrapRef()->is_scalar()) { return nullptr; } diff --git a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc index 1ee9d337b7..f92df44cca 100644 --- a/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc +++ b/src/tint/transform/vectorize_scalar_matrix_constructors_test.cc @@ -81,6 +81,58 @@ fn main() { EXPECT_EQ(expect, str(got)); } +TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalarsReference) { + uint32_t cols = GetParam().first; + uint32_t rows = GetParam().second; + std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows); + std::string matrix = matrix_no_type + ""; + std::string vector = "vec" + std::to_string(rows) + ""; + std::string values; + for (uint32_t c = 0; c < cols; c++) { + if (c > 0) { + values += ", "; + } + values += vector + "("; + for (uint32_t r = 0; r < rows; r++) { + if (r > 0) { + values += ", "; + } + values += "value"; + } + values += ")"; + } + + std::string src = R"( +@fragment +fn main() { + let v = vec4(42.0); + let m = ${matrix}(v[2]); +} +)"; + + std::string expect = R"( +fn build_${matrix_no_type}(value : f32) -> ${matrix} { + return ${matrix}(${values}); +} + +@fragment +fn main() { + let v = vec4(42.0); + let m = build_${matrix_no_type}(v[2]); +} +)"; + src = utils::ReplaceAll(src, "${matrix}", matrix); + expect = utils::ReplaceAll(expect, "${matrix}", matrix); + expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type); + expect = utils::ReplaceAll(expect, "${values}", values); + + EXPECT_TRUE(ShouldRun(src)); + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) { uint32_t cols = GetParam().first; uint32_t rows = GetParam().second; @@ -123,6 +175,49 @@ fn main() { EXPECT_EQ(expect, str(got)); } +TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalarsReference) { + uint32_t cols = GetParam().first; + uint32_t rows = GetParam().second; + std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + ""; + std::string vec_type = "vec" + std::to_string(rows) + ""; + std::string scalar_values; + std::string vector_values; + for (uint32_t c = 0; c < cols; c++) { + if (c > 0) { + vector_values += ", "; + scalar_values += ", "; + } + vector_values += vec_type + "("; + for (uint32_t r = 0; r < rows; r++) { + if (r > 0) { + scalar_values += ", "; + vector_values += ", "; + } + auto value = "v[" + std::to_string((c * rows + r) % 4) + "]"; + scalar_values += value; + vector_values += value; + } + vector_values += ")"; + } + + std::string tmpl = R"( +@fragment +fn main() { + let v = vec4(1.0, 2.0, 3.0, 8.0); + let m = ${matrix}(${values}); +} +)"; + tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type); + auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values); + auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values); + + EXPECT_TRUE(ShouldRun(src)); + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) { uint32_t cols = GetParam().first; uint32_t rows = GetParam().second;