mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-07 05:36:04 +00:00
Tint: Fix VectorizeScalarMatrixConstructors to run for ref to scalar
This patch make VectorizeScalarMatrixConstructors transform run for reference to scalar as well as scalar types node, i.e. run for `mat2x2<f32>(v[2])`, where `v` is a f32 vector, as well as `mat2x2<f32>(1.0)`. Bug: tint:1589 Change-Id: I5d3e367ee6a9826b3e1add3495aaac0ae326be14 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94023 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
parent
889a499ef4
commit
d713669329
@ -36,7 +36,7 @@ bool VectorizeScalarMatrixConstructors::ShouldRun(const Program* program, const
|
|||||||
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
||||||
if (call->Target()->Is<sem::TypeConstructor>() && call->Type()->Is<sem::Matrix>()) {
|
if (call->Target()->Is<sem::TypeConstructor>() && call->Type()->Is<sem::Matrix>()) {
|
||||||
auto& args = call->Arguments();
|
auto& args = call->Arguments();
|
||||||
if (args.size() > 0 && args[0]->Type()->is_scalar()) {
|
if (args.size() > 0 && args[0]->Type()->UnwrapRef()->is_scalar()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -64,7 +64,7 @@ void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx, const DataMap&, D
|
|||||||
if (args.size() == 0) {
|
if (args.size() == 0) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (!args[0]->Type()->is_scalar()) {
|
if (!args[0]->Type()->UnwrapRef()->is_scalar()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,6 +81,58 @@ fn main() {
|
|||||||
EXPECT_EQ(expect, str(got));
|
EXPECT_EQ(expect, str(got));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(VectorizeScalarMatrixConstructorsTest, SingleScalarsReference) {
|
||||||
|
uint32_t cols = GetParam().first;
|
||||||
|
uint32_t rows = GetParam().second;
|
||||||
|
std::string matrix_no_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows);
|
||||||
|
std::string matrix = matrix_no_type + "<f32>";
|
||||||
|
std::string vector = "vec" + std::to_string(rows) + "<f32>";
|
||||||
|
std::string values;
|
||||||
|
for (uint32_t c = 0; c < cols; c++) {
|
||||||
|
if (c > 0) {
|
||||||
|
values += ", ";
|
||||||
|
}
|
||||||
|
values += vector + "(";
|
||||||
|
for (uint32_t r = 0; r < rows; r++) {
|
||||||
|
if (r > 0) {
|
||||||
|
values += ", ";
|
||||||
|
}
|
||||||
|
values += "value";
|
||||||
|
}
|
||||||
|
values += ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string src = R"(
|
||||||
|
@fragment
|
||||||
|
fn main() {
|
||||||
|
let v = vec4<f32>(42.0);
|
||||||
|
let m = ${matrix}(v[2]);
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
std::string expect = R"(
|
||||||
|
fn build_${matrix_no_type}(value : f32) -> ${matrix} {
|
||||||
|
return ${matrix}(${values});
|
||||||
|
}
|
||||||
|
|
||||||
|
@fragment
|
||||||
|
fn main() {
|
||||||
|
let v = vec4<f32>(42.0);
|
||||||
|
let m = build_${matrix_no_type}(v[2]);
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
src = utils::ReplaceAll(src, "${matrix}", matrix);
|
||||||
|
expect = utils::ReplaceAll(expect, "${matrix}", matrix);
|
||||||
|
expect = utils::ReplaceAll(expect, "${matrix_no_type}", matrix_no_type);
|
||||||
|
expect = utils::ReplaceAll(expect, "${values}", values);
|
||||||
|
|
||||||
|
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||||
|
|
||||||
|
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||||
|
|
||||||
|
EXPECT_EQ(expect, str(got));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
|
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalars) {
|
||||||
uint32_t cols = GetParam().first;
|
uint32_t cols = GetParam().first;
|
||||||
uint32_t rows = GetParam().second;
|
uint32_t rows = GetParam().second;
|
||||||
@ -123,6 +175,49 @@ fn main() {
|
|||||||
EXPECT_EQ(expect, str(got));
|
EXPECT_EQ(expect, str(got));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(VectorizeScalarMatrixConstructorsTest, MultipleScalarsReference) {
|
||||||
|
uint32_t cols = GetParam().first;
|
||||||
|
uint32_t rows = GetParam().second;
|
||||||
|
std::string mat_type = "mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
|
||||||
|
std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
|
||||||
|
std::string scalar_values;
|
||||||
|
std::string vector_values;
|
||||||
|
for (uint32_t c = 0; c < cols; c++) {
|
||||||
|
if (c > 0) {
|
||||||
|
vector_values += ", ";
|
||||||
|
scalar_values += ", ";
|
||||||
|
}
|
||||||
|
vector_values += vec_type + "(";
|
||||||
|
for (uint32_t r = 0; r < rows; r++) {
|
||||||
|
if (r > 0) {
|
||||||
|
scalar_values += ", ";
|
||||||
|
vector_values += ", ";
|
||||||
|
}
|
||||||
|
auto value = "v[" + std::to_string((c * rows + r) % 4) + "]";
|
||||||
|
scalar_values += value;
|
||||||
|
vector_values += value;
|
||||||
|
}
|
||||||
|
vector_values += ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string tmpl = R"(
|
||||||
|
@fragment
|
||||||
|
fn main() {
|
||||||
|
let v = vec4<f32>(1.0, 2.0, 3.0, 8.0);
|
||||||
|
let m = ${matrix}(${values});
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
|
||||||
|
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
|
||||||
|
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
|
||||||
|
|
||||||
|
EXPECT_TRUE(ShouldRun<VectorizeScalarMatrixConstructors>(src));
|
||||||
|
|
||||||
|
auto got = Run<VectorizeScalarMatrixConstructors>(src);
|
||||||
|
|
||||||
|
EXPECT_EQ(expect, str(got));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
|
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
|
||||||
uint32_t cols = GetParam().first;
|
uint32_t cols = GetParam().first;
|
||||||
uint32_t rows = GetParam().second;
|
uint32_t rows = GetParam().second;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user