Tint: Fix VectorizeScalarMatrixConstructors to run for ref to scalar

This patch make VectorizeScalarMatrixConstructors transform run for
reference to scalar as well as scalar types node, i.e. run for
`mat2x2<f32>(v[2])`, where `v` is a f32 vector, as well as
`mat2x2<f32>(1.0)`.

Bug: tint:1589
Change-Id: I5d3e367ee6a9826b3e1add3495aaac0ae326be14
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94023
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
Zhaoming Jiang 2022-06-17 15:36:59 +00:00 committed by Dawn LUCI CQ
parent 889a499ef4
commit d713669329
2 changed files with 97 additions and 2 deletions

View File

@ -36,7 +36,7 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeConstructor>() && call->Type()->Is<sem::Matrix>()) {
auto& args = call->Arguments();
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
if (args.size() > 0 && args[0]->Type()->UnwrapRef()->is_scalar()) {
return true;
}
}
@ -64,7 +64,7 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
if (args.size() == 0) {
return nullptr;
}
if (!args[0]->Type()->is_scalar()) {
if (!args[0]->Type()->UnwrapRef()->is_scalar()) {
return nullptr;
}

View File

@ -81,6 +81,58 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalarsReference) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
std::string matrix = matrix_no_type + "<f32>";
std::string vector = "vec" + std::to_string(rows) + "<f32>";
std::string values;
for (uint32_t c = 0; c < cols; c++) {
if (c > 0) {
values += ", ";
}
values += vector + "(";
for (uint32_t r = 0; r < rows; r++) {
if (r > 0) {
values += ", ";
}
values += "value";
}
values += ")";
}
std::string src = R"(
@fragment
fn main() {
let v = vec4<f32>(42.0);
let m = ${matrix}(v[2]);
}
)";
std::string expect = R"(
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
return ${matrix}(${values});
}
@fragment
fn main() {
let v = vec4<f32>(42.0);
let m = build_${matrix_no_type}(v[2]);
}
)";
src = utils::ReplaceAll(src, "${matrix}", matrix);
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
expect = utils::ReplaceAll(expect, "${values}", values);
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
@ -123,6 +175,49 @@ fn main() {
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalarsReference) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
std::string scalar_values;
std::string vector_values;
for (uint32_t c = 0; c < cols; c++) {
if (c > 0) {
vector_values += ", ";
scalar_values += ", ";
}
vector_values += vec_type + "(";
for (uint32_t r = 0; r < rows; r++) {
if (r > 0) {
scalar_values += ", ";
vector_values += ", ";
}
auto value = "v[" + std::to_string((c * rows + r) % 4) + "]";
scalar_values += value;
vector_values += value;
}
vector_values += ")";
}
std::string tmpl = R"(
@fragment
fn main() {
let v = vec4<f32>(1.0, 2.0, 3.0, 8.0);
let m = ${matrix}(${values});
}
)";
tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;