Tint: Fix VectorizeScalarMatrixConstructors to run for ref to scalar
This patch make VectorizeScalarMatrixConstructors transform run for reference to scalar as well as scalar types node, i.e. run for `mat2x2<f32>(v[2])`, where `v` is a f32 vector, as well as `mat2x2<f32>(1.0)`. Bug: tint:1589 Change-Id: I5d3e367ee6a9826b3e1add3495aaac0ae326be14 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94023 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
parent
889a499ef4
commit
d713669329
|
@ -36,7 +36,7 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
|
|||
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
||||
if (call->Target()->Is<sem::TypeConstructor>() && call->Type()->Is<sem::Matrix>()) {
|
||||
auto& args = call->Arguments();
|
||||
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
|
||||
if (args.size() > 0 && args[0]->Type()->UnwrapRef()->is_scalar()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
|
|||
if (args.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!args[0]->Type()->is_scalar()) {
|
||||
if (!args[0]->Type()->UnwrapRef()->is_scalar()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -81,6 +81,58 @@ fn main() {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalarsReference) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
|
||||
std::string matrix = matrix_no_type + "<f32>";
|
||||
std::string vector = "vec" + std::to_string(rows) + "<f32>";
|
||||
std::string values;
|
||||
for (uint32_t c = 0; c < cols; c++) {
|
||||
if (c > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += vector + "(";
|
||||
for (uint32_t r = 0; r < rows; r++) {
|
||||
if (r > 0) {
|
||||
values += ", ";
|
||||
}
|
||||
values += "value";
|
||||
}
|
||||
values += ")";
|
||||
}
|
||||
|
||||
std::string src = R"(
|
||||
@fragment
|
||||
fn main() {
|
||||
let v = vec4<f32>(42.0);
|
||||
let m = ${matrix}(v[2]);
|
||||
}
|
||||
)";
|
||||
|
||||
std::string expect = R"(
|
||||
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
|
||||
return ${matrix}(${values});
|
||||
}
|
||||
|
||||
@fragment
|
||||
fn main() {
|
||||
let v = vec4<f32>(42.0);
|
||||
let m = build_${matrix_no_type}(v[2]);
|
||||
}
|
||||
)";
|
||||
src = utils::ReplaceAll(src, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
|
||||
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
|
||||
expect = utils::ReplaceAll(expect, "${values}", values);
|
||||
|
||||
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
|
@ -123,6 +175,49 @@ fn main() {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalarsReference) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
|
||||
std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
|
||||
std::string scalar_values;
|
||||
std::string vector_values;
|
||||
for (uint32_t c = 0; c < cols; c++) {
|
||||
if (c > 0) {
|
||||
vector_values += ", ";
|
||||
scalar_values += ", ";
|
||||
}
|
||||
vector_values += vec_type + "(";
|
||||
for (uint32_t r = 0; r < rows; r++) {
|
||||
if (r > 0) {
|
||||
scalar_values += ", ";
|
||||
vector_values += ", ";
|
||||
}
|
||||
auto value = "v[" + std::to_string((c * rows + r) % 4) + "]";
|
||||
scalar_values += value;
|
||||
vector_values += value;
|
||||
}
|
||||
vector_values += ")";
|
||||
}
|
||||
|
||||
std::string tmpl = R"(
|
||||
@fragment
|
||||
fn main() {
|
||||
let v = vec4<f32>(1.0, 2.0, 3.0, 8.0);
|
||||
let m = ${matrix}(${values});
|
||||
}
|
||||
)";
|
||||
tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
|
||||
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
|
||||
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
|
||||
|
||||
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||
|
||||
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
|
||||
uint32_t cols = GetParam().first;
|
||||
uint32_t rows = GetParam().second;
|
||||
|
|
Loading…
Reference in New Issue