Add support for matrix construction from scalars

Use a transform to convert these to the vector form for the MSL and
SPIR-V backends.

MSL only has the scalar form from version 2.0 onwards.

Fixed: tint:1123
Change-Id: I384abd9872d9eae52a10a37cbd6aa96004692e9c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67360
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-10-25 19:20:31 +00:00 committed by Tint LUCI CQ
parent c73b57f2c6
commit 91689fb6f0
56 changed files with 832 additions and 87 deletions

View File

@ -17,6 +17,7 @@
* `any()` and `all()` now support a `bool` parameter. These simply return the passed argument. [tint:1253](https://crbug.com/tint/1253)
* Call statements may now include functions that return a value (`ignore()` is no longer needed).
* The `interpolate(flat)` attribute can now be specified on integral user-defined IO. It will eventually become an error to define integral user-defined IO without this attribute.
* Matrix construction from scalar element values is now supported.
### Fixes

View File

@ -316,8 +316,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/switch_statement.h",
"ast/texture.cc",
"ast/texture.h",
"ast/type.h",
"ast/traverse_expressions.h",
"ast/type.h",
"ast/type_constructor_expression.cc",
"ast/type_constructor_expression.h",
"ast/type_decl.cc",
@ -475,6 +475,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/single_entry_point.h",
"transform/transform.cc",
"transform/transform.h",
"transform/vectorize_scalar_matrix_constructors.cc",
"transform/vectorize_scalar_matrix_constructors.h",
"transform/vertex_pulling.cc",
"transform/vertex_pulling.h",
"transform/wrap_arrays_in_structs.cc",

View File

@ -341,6 +341,8 @@ set(TINT_LIB_SRCS
transform/single_entry_point.h
transform/transform.cc
transform/transform.h
transform/vectorize_scalar_matrix_constructors.cc
transform/vectorize_scalar_matrix_constructors.h
transform/vertex_pulling.cc
transform/vertex_pulling.h
transform/wrap_arrays_in_structs.cc
@ -963,6 +965,7 @@ if(${TINT_BUILD_TESTS})
transform/simplify_test.cc
transform/single_entry_point_test.cc
transform/test_helper.h
transform/vectorize_scalar_matrix_constructors.cc
transform/vertex_pulling_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc

View File

@ -2885,28 +2885,47 @@ bool Resolver::ValidateMatrixConstructor(
}
auto* elem_type = matrix_type->type();
if (matrix_type->columns() != values.size()) {
auto num_elements = matrix_type->columns() * matrix_type->rows();
// Print a generic error for an invalid matrix constructor, showing the
// available overloads.
auto print_error = [&]() {
const Source& values_start = values[0]->source;
const Source& values_end = values[values.size() - 1]->source;
AddError("expected " + std::to_string(matrix_type->columns()) + " '" +
VectorPretty(matrix_type->rows(), elem_type) +
"' arguments in '" + type_name + "' constructor, found " +
std::to_string(values.size()),
Source::Combine(values_start, values_end));
auto elem_type_name = elem_type->FriendlyName(builder_->Symbols());
std::stringstream ss;
ss << "invalid constructor for " + type_name << std::endl << std::endl;
ss << "3 candidates available:" << std::endl;
ss << " " << type_name << "()" << std::endl;
ss << " " << type_name << "(" << elem_type_name << ",...,"
<< elem_type_name << ")"
<< " // " << std::to_string(num_elements) << " arguments" << std::endl;
ss << " " << type_name << "(";
for (uint32_t c = 0; c < matrix_type->columns(); c++) {
if (c > 0) {
ss << ", ";
}
ss << VectorPretty(matrix_type->rows(), elem_type);
}
ss << ")" << std::endl;
AddError(ss.str(), Source::Combine(values_start, values_end));
};
const sem::Type* expected_arg_type = nullptr;
if (num_elements == values.size()) {
// Column-major construction from scalar elements.
expected_arg_type = matrix_type->type();
} else if (matrix_type->columns() == values.size()) {
// Column-by-column construction from vectors.
expected_arg_type = matrix_type->ColumnType();
} else {
print_error();
return false;
}
for (auto* value : values) {
auto* value_type = TypeOf(value)->UnwrapRef();
auto* value_vec = value_type->As<sem::Vector>();
if (!value_vec || value_vec->Width() != matrix_type->rows() ||
elem_type != value_vec->type()) {
AddError("expected argument type '" +
VectorPretty(matrix_type->rows(), elem_type) + "' in '" +
type_name + "' constructor, found '" + TypeNameOf(value) +
"'",
value->source);
if (TypeOf(value)->UnwrapRef() != expected_arg_type) {
print_error();
return false;
}
}

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "src/resolver/resolver_test_helper.h"
#include "src/sem/reference_type.h"
@ -19,6 +20,8 @@ namespace tint {
namespace resolver {
namespace {
using ::testing::HasSubstr;
// Helpers and typedefs
using builder::alias;
using builder::alias1;
@ -1641,13 +1644,9 @@ static std::string MatrixStr(const MatrixDimensions& dimensions,
std::to_string(dimensions.rows) + "<" + subtype + ">";
}
static std::string VecStr(uint32_t dimensions, std::string subtype = "f32") {
return "vec" + std::to_string(dimensions) + "<" + subtype + ">";
}
using MatrixConstructorTest = ResolverTestWithParam<MatrixDimensions>;
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooFewArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N - 1 arguments
const auto param = GetParam();
@ -1665,13 +1664,34 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooFewArguments) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected " + std::to_string(param.columns) + " '" +
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
"' constructor, found " + std::to_string(param.columns - 1));
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooFewArguments) {
// matNxM<f32>(f32,...,f32); with N*M - 1 arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns * param.rows - 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, ty.f32(), ExprList()));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_ColumnConstructor_Error_TooManyArguments) {
// matNxM<f32>(vecM<f32>(), ...); with N + 1 arguments
const auto param = GetParam();
@ -1689,21 +1709,20 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_TooManyArguments) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected " + std::to_string(param.columns) + " '" +
VecStr(param.rows) + "' arguments in '" + MatrixStr(param) +
"' constructor, found " + std::to_string(param.columns + 1));
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
// matNxM<f32>(1.0, 1.0, ...); N arguments
TEST_P(MatrixConstructorTest, Expr_ElementConstructor_Error_TooManyArguments) {
// matNxM<f32>(f32,...,f32); with N*M + 1 arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(create<ast::ScalarConstructorExpression>(Source{{12, i}},
Literal(1.0f)));
for (uint32_t i = 1; i <= param.columns * param.rows + 1; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, ty.f32(), ExprList()));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
@ -1712,13 +1731,60 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_Error_InvalidArgumentType) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found 'f32'");
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_TooFewRowsInVectorArgument) {
Expr_ColumnConstructor_Error_InvalidArgumentType) {
// matNxM<f32>(vec<u32>, vec<u32>, ...); N arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
auto* vec_type = ty.vec<u32>(param.rows);
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest,
Expr_ElementConstructor_Error_InvalidArgumentType) {
// matNxM<f32>(u32, u32, ...); N*M arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
args.push_back(
create<ast::ScalarConstructorExpression>(Source{{12, i}}, Literal(1u)));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest,
Expr_ColumnConstructor_Error_TooFewRowsInVectorArgument) {
// matNxM<f32>(vecM<f32>(),...,vecM-1<f32>());
const auto param = GetParam();
@ -1745,15 +1811,13 @@ TEST_P(MatrixConstructorTest,
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
" error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows - 1) + "'");
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_TooManyRowsInVectorArgument) {
Expr_ColumnConstructor_Error_TooManyRowsInVectorArgument) {
// matNxM<f32>(vecM<f32>(),...,vecM+1<f32>());
const auto param = GetParam();
@ -1780,36 +1844,9 @@ TEST_P(MatrixConstructorTest,
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:" + std::to_string(kInvalidLoc) +
" error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows + 1) + "'");
}
TEST_P(MatrixConstructorTest,
Expr_Constructor_Error_ArgumentVectorElementTypeMismatch) {
// matNxM<f32>(vecM<u32>(), ...); with N arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns; i++) {
auto* vec_type = ty.vec<u32>(param.rows);
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, vec_type, ExprList()));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows, "u32") + "'");
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
@ -1824,7 +1861,7 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ZeroValue_Success) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
TEST_P(MatrixConstructorTest, Expr_Constructor_WithColumns_Success) {
// matNxM<f32>(vecM<f32>(), ...); with N arguments
const auto param = GetParam();
@ -1844,6 +1881,25 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_WithArguments_Success) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_WithElements_Success) {
// matNxM<f32>(f32,...,f32); with N*M arguments
const auto param = GetParam();
ast::ExpressionList args;
for (uint32_t i = 1; i <= param.columns * param.rows; i++) {
args.push_back(create<ast::TypeConstructorExpression>(
Source{{12, i}}, ty.f32(), ExprList()));
}
auto* matrix_type = ty.mat<f32>(param.columns, param.rows);
auto* tc = create<ast::TypeConstructorExpression>(Source{}, matrix_type,
std::move(args));
WrapInFunction(tc);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
// matNxM<Float32>(vecM<u32>(), ...); with N arguments
@ -1863,10 +1919,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Error) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:1 error: expected argument type '" + VecStr(param.rows) +
"' in '" + MatrixStr(param, "Float32") +
"' constructor, found '" + VecStr(param.rows, "u32") + "'");
EXPECT_THAT(r()->error(), HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param, "Float32") +
"\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ElementTypeAlias_Success) {
@ -1900,8 +1955,13 @@ TEST_F(ResolverTypeConstructorValidationTest,
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected argument type 'vec2<f32>' in 'mat2x2<f32>' "
"constructor, found 'VectorUnsigned2'");
R"(12:34 error: invalid constructor for mat2x2<f32>
3 candidates available:
mat2x2<f32>()
mat2x2<f32>(f32,...,f32) // 4 arguments
mat2x2<f32>(vec2<f32>, vec2<f32>)
)");
}
TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentTypeAlias_Success) {
@ -1940,10 +2000,9 @@ TEST_P(MatrixConstructorTest, Expr_Constructor_ArgumentElementTypeAlias_Error) {
WrapInFunction(tc);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:1 error: expected argument type '" +
VecStr(param.rows) + "' in '" + MatrixStr(param) +
"' constructor, found '" +
VecStr(param.rows, "UnsignedInt") + "'");
EXPECT_THAT(r()->error(),
HasSubstr("12:1 error: invalid constructor for " +
MatrixStr(param) + "\n\n3 candidates available:"));
}
TEST_P(MatrixConstructorTest,

View File

@ -0,0 +1,73 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include <utility>
#include "src/program_builder.h"
#include "src/sem/expression.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixConstructors);
namespace tint {
namespace transform {
VectorizeScalarMatrixConstructors::VectorizeScalarMatrixConstructors() =
default;
VectorizeScalarMatrixConstructors::~VectorizeScalarMatrixConstructors() =
default;
void VectorizeScalarMatrixConstructors::Run(CloneContext& ctx,
const DataMap&,
DataMap&) {
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* constructor)
-> const ast::TypeConstructorExpression* {
// Check if this is a matrix constructor with scalar arguments.
auto* mat_type = ctx.src->Sem().Get(constructor->type)->As<sem::Matrix>();
if (!mat_type) {
return nullptr;
}
if (constructor->values.size() == 0) {
return nullptr;
}
if (!ctx.src->Sem().Get(constructor->values[0])->Type()->is_scalar()) {
return nullptr;
}
// Build a list of vector expressions for each column.
ast::ExpressionList columns;
for (uint32_t c = 0; c < mat_type->columns(); c++) {
// Build a list of scalar expressions for each value in the column.
ast::ExpressionList row_values;
for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values.push_back(
ctx.Clone(constructor->values[c * mat_type->rows() + r]));
}
// Construct the column vector.
auto* col = ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()),
mat_type->rows(), row_values);
columns.push_back(col);
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
});
ctx.Clone();
}
} // namespace transform
} // namespace tint

View File

@ -0,0 +1,46 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_
#define SRC_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_
#include "src/transform/transform.h"
namespace tint {
namespace transform {
/// A transform that converts scalar matrix constructors to the vector form.
class VectorizeScalarMatrixConstructors
: public Castable<VectorizeScalarMatrixConstructors, Transform> {
public:
/// Constructor
VectorizeScalarMatrixConstructors();
/// Destructor
~VectorizeScalarMatrixConstructors() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
};
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_VECTORIZE_SCALAR_MATRIX_CONSTRUCTORS_H_

View File

@ -0,0 +1,114 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include <string>
#include <utility>
#include "src/transform/test_helper.h"
#include "src/utils/string.h"
namespace tint {
namespace transform {
namespace {
using VectorizeScalarMatrixConstructorsTest =
TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string mat_type =
"mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
std::string scalar_values;
std::string vector_values;
for (uint32_t c = 0; c < cols; c++) {
if (c > 0) {
vector_values += ", ";
scalar_values += ", ";
}
vector_values += vec_type + "(";
for (uint32_t r = 0; r < rows; r++) {
if (r > 0) {
scalar_values += ", ";
vector_values += ", ";
}
auto value = std::to_string(c * rows + r) + ".0";
scalar_values += value;
vector_values += value;
}
vector_values += ")";
}
std::string tmpl = R"(
[[stage(fragment)]]
fn main() {
let m = ${matrix}(${values});
}
)";
tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
auto src = utils::ReplaceAll(tmpl, "${values}", scalar_values);
auto expect = utils::ReplaceAll(tmpl, "${values}", vector_values);
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
}
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
uint32_t cols = GetParam().first;
uint32_t rows = GetParam().second;
std::string mat_type =
"mat" + std::to_string(cols) + "x" + std::to_string(rows) + "<f32>";
std::string vec_type = "vec" + std::to_string(rows) + "<f32>";
std::string columns;
for (uint32_t c = 0; c < cols; c++) {
if (c > 0) {
columns += ", ";
}
columns += vec_type + "()";
}
std::string tmpl = R"(
[[stage(fragment)]]
fn main() {
let m = ${matrix}(${columns});
}
)";
tmpl = utils::ReplaceAll(tmpl, "${matrix}", mat_type);
auto src = utils::ReplaceAll(tmpl, "${columns}", columns);
auto expect = src;
auto got = Run<VectorizeScalarMatrixConstructors>(src);
EXPECT_EQ(expect, str(got));
}
INSTANTIATE_TEST_SUITE_P(VectorizeScalarMatrixConstructorsTest,
VectorizeScalarMatrixConstructorsTest,
testing::Values(std::make_pair(2, 2),
std::make_pair(2, 3),
std::make_pair(2, 4),
std::make_pair(3, 2),
std::make_pair(3, 3),
std::make_pair(3, 4),
std::make_pair(4, 2),
std::make_pair(4, 3),
std::make_pair(4, 4)));
} // namespace
} // namespace transform
} // namespace tint

View File

@ -65,6 +65,7 @@
#include "src/transform/promote_initializers_to_const_var.h"
#include "src/transform/remove_phonies.h"
#include "src/transform/simplify.h"
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include "src/transform/wrap_arrays_in_structs.h"
#include "src/transform/zero_init_workgroup_memory.h"
#include "src/utils/defer.h"
@ -142,6 +143,7 @@ SanitizedResult Sanitize(const Program* in,
manager.Add<transform::CanonicalizeEntryPointIO>();
manager.Add<transform::ExternalTextureTransform>();
manager.Add<transform::PromoteInitializersToConstVar>();
manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::WrapArraysInStructs>();
manager.Add<transform::PadArrayElements>();
manager.Add<transform::ModuleScopeVarToEntryPointParam>();

View File

@ -44,6 +44,7 @@
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
#include "src/transform/simplify.h"
#include "src/transform/vectorize_scalar_matrix_constructors.h"
#include "src/transform/zero_init_workgroup_memory.h"
#include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h"
@ -268,6 +269,7 @@ SanitizedResult Sanitize(const Program* in,
manager.Add<transform::Simplify>(); // Required for arrayLength()
manager.Add<transform::FoldConstants>();
manager.Add<transform::ExternalTextureTransform>();
manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::ForLoopToLoop>(); // Must come after
// ZeroInitWorkgroupMemory
manager.Add<transform::CanonicalizeEntryPointIO>();

View File

@ -314,6 +314,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/transform/single_entry_point_test.cc",
"../src/transform/test_helper.h",
"../src/transform/transform_test.cc",
"../src/transform/vectorize_scalar_matrix_constructors_test.cc",
"../src/transform/vertex_pulling_test.cc",
"../src/transform/wrap_arrays_in_structs_test.cc",
"../src/transform/zero_init_workgroup_memory_test.cc",

View File

@ -0,0 +1,2 @@
let m = mat2x2<f32>(0.0, 1.0,
2.0, 3.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float2x2 m = float2x2(0.0f, 1.0f, 2.0f, 3.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float2x2 m = float2x2(float2(0.0f, 1.0f), float2(2.0f, 3.0f));

View File

@ -0,0 +1,27 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 15
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat2v2float = OpTypeMatrix %v2float 2
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%6 = OpConstantComposite %v2float %float_0 %float_1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%9 = OpConstantComposite %v2float %float_2 %float_3
%m = OpConstantComposite %mat2v2float %6 %9
%void = OpTypeVoid
%11 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %11
%14 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat2x2<f32>(0.0, 1.0, 2.0, 3.0);

View File

@ -0,0 +1,2 @@
let m = mat2x3<f32>(0.0, 1.0, 2.0,
3.0, 4.0, 5.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float2x3 m = float2x3(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float2x3 m = float2x3(float3(0.0f, 1.0f, 2.0f), float3(3.0f, 4.0f, 5.0f));

View File

@ -0,0 +1,29 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 17
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%mat2v3float = OpTypeMatrix %v3float 2
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%7 = OpConstantComposite %v3float %float_0 %float_1 %float_2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%11 = OpConstantComposite %v3float %float_3 %float_4 %float_5
%m = OpConstantComposite %mat2v3float %7 %11
%void = OpTypeVoid
%13 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %13
%16 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat2x3<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0);

View File

@ -0,0 +1,2 @@
let m = mat2x4<f32>(0.0, 1.0, 2.0, 3.0,
4.0, 5.0, 6.0, 7.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float2x4 m = float2x4(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float2x4 m = float2x4(float4(0.0f, 1.0f, 2.0f, 3.0f), float4(4.0f, 5.0f, 6.0f, 7.0f));

View File

@ -0,0 +1,31 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 19
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%mat2v4float = OpTypeMatrix %v4float 2
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%8 = OpConstantComposite %v4float %float_0 %float_1 %float_2 %float_3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%13 = OpConstantComposite %v4float %float_4 %float_5 %float_6 %float_7
%m = OpConstantComposite %mat2v4float %8 %13
%void = OpTypeVoid
%15 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %15
%18 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat2x4<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

View File

@ -0,0 +1,3 @@
let m = mat3x2<f32>(0.0, 1.0,
2.0, 3.0,
4.0, 5.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float3x2 m = float3x2(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float3x2 m = float3x2(float2(0.0f, 1.0f), float2(2.0f, 3.0f), float2(4.0f, 5.0f));

View File

@ -0,0 +1,30 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 18
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat3v2float = OpTypeMatrix %v2float 3
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%6 = OpConstantComposite %v2float %float_0 %float_1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%9 = OpConstantComposite %v2float %float_2 %float_3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%12 = OpConstantComposite %v2float %float_4 %float_5
%m = OpConstantComposite %mat3v2float %6 %9 %12
%void = OpTypeVoid
%14 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %14
%17 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat3x2<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0);

View File

@ -0,0 +1,3 @@
let m = mat3x3<f32>(0.0, 1.0, 2.0,
3.0, 4.0, 5.0,
6.0, 7.0, 8.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float3x3 m = float3x3(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float3x3 m = float3x3(float3(0.0f, 1.0f, 2.0f), float3(3.0f, 4.0f, 5.0f), float3(6.0f, 7.0f, 8.0f));

View File

@ -0,0 +1,33 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 21
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%mat3v3float = OpTypeMatrix %v3float 3
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%7 = OpConstantComposite %v3float %float_0 %float_1 %float_2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%11 = OpConstantComposite %v3float %float_3 %float_4 %float_5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%float_8 = OpConstant %float 8
%15 = OpConstantComposite %v3float %float_6 %float_7 %float_8
%m = OpConstantComposite %mat3v3float %7 %11 %15
%void = OpTypeVoid
%17 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %17
%20 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat3x3<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);

View File

@ -0,0 +1,3 @@
let m = mat3x4<f32>(0.0, 1.0, 2.0, 3.0,
4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float3x4 m = float3x4(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float3x4 m = float3x4(float4(0.0f, 1.0f, 2.0f, 3.0f), float4(4.0f, 5.0f, 6.0f, 7.0f), float4(8.0f, 9.0f, 10.0f, 11.0f));

View File

@ -0,0 +1,36 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 24
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%mat3v4float = OpTypeMatrix %v4float 3
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%8 = OpConstantComposite %v4float %float_0 %float_1 %float_2 %float_3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%13 = OpConstantComposite %v4float %float_4 %float_5 %float_6 %float_7
%float_8 = OpConstant %float 8
%float_9 = OpConstant %float 9
%float_10 = OpConstant %float 10
%float_11 = OpConstant %float 11
%18 = OpConstantComposite %v4float %float_8 %float_9 %float_10 %float_11
%m = OpConstantComposite %mat3v4float %8 %13 %18
%void = OpTypeVoid
%20 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %20
%23 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat3x4<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0);

View File

@ -0,0 +1,4 @@
let m = mat4x2<f32>(0.0, 1.0,
2.0, 3.0,
4.0, 5.0,
6.0, 7.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float4x2 m = float4x2(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float4x2 m = float4x2(float2(0.0f, 1.0f), float2(2.0f, 3.0f), float2(4.0f, 5.0f), float2(6.0f, 7.0f));

View File

@ -0,0 +1,33 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 21
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v2float = OpTypeVector %float 2
%mat4v2float = OpTypeMatrix %v2float 4
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%6 = OpConstantComposite %v2float %float_0 %float_1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%9 = OpConstantComposite %v2float %float_2 %float_3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%12 = OpConstantComposite %v2float %float_4 %float_5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%15 = OpConstantComposite %v2float %float_6 %float_7
%m = OpConstantComposite %mat4v2float %6 %9 %12 %15
%void = OpTypeVoid
%17 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %17
%20 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat4x2<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);

View File

@ -0,0 +1,4 @@
let m = mat4x3<f32>(0.0, 1.0, 2.0,
3.0, 4.0, 5.0,
6.0, 7.0, 8.0,
9.0, 10.0, 11.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float4x3 m = float4x3(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float4x3 m = float4x3(float3(0.0f, 1.0f, 2.0f), float3(3.0f, 4.0f, 5.0f), float3(6.0f, 7.0f, 8.0f), float3(9.0f, 10.0f, 11.0f));

View File

@ -0,0 +1,37 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 25
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v3float = OpTypeVector %float 3
%mat4v3float = OpTypeMatrix %v3float 4
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%7 = OpConstantComposite %v3float %float_0 %float_1 %float_2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%11 = OpConstantComposite %v3float %float_3 %float_4 %float_5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%float_8 = OpConstant %float 8
%15 = OpConstantComposite %v3float %float_6 %float_7 %float_8
%float_9 = OpConstant %float 9
%float_10 = OpConstant %float 10
%float_11 = OpConstant %float 11
%19 = OpConstantComposite %v3float %float_9 %float_10 %float_11
%m = OpConstantComposite %mat4v3float %7 %11 %15 %19
%void = OpTypeVoid
%21 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %21
%24 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat4x3<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0);

View File

@ -0,0 +1,4 @@
let m = mat4x4<f32>(0.0, 1.0, 2.0, 3.0,
4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0,
12.0, 13.0, 14.0, 15.0);

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
static const float4x4 m = float4x4(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f);

View File

@ -0,0 +1,4 @@
#include <metal_stdlib>
using namespace metal;
constant float4x4 m = float4x4(float4(0.0f, 1.0f, 2.0f, 3.0f), float4(4.0f, 5.0f, 6.0f, 7.0f), float4(8.0f, 9.0f, 10.0f, 11.0f), float4(12.0f, 13.0f, 14.0f, 15.0f));

View File

@ -0,0 +1,41 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 29
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %m "m"
OpName %unused_entry_point "unused_entry_point"
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%mat4v4float = OpTypeMatrix %v4float 4
%float_0 = OpConstant %float 0
%float_1 = OpConstant %float 1
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%8 = OpConstantComposite %v4float %float_0 %float_1 %float_2 %float_3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%float_6 = OpConstant %float 6
%float_7 = OpConstant %float 7
%13 = OpConstantComposite %v4float %float_4 %float_5 %float_6 %float_7
%float_8 = OpConstant %float 8
%float_9 = OpConstant %float 9
%float_10 = OpConstant %float 10
%float_11 = OpConstant %float 11
%18 = OpConstantComposite %v4float %float_8 %float_9 %float_10 %float_11
%float_12 = OpConstant %float 12
%float_13 = OpConstant %float 13
%float_14 = OpConstant %float 14
%float_15 = OpConstant %float 15
%23 = OpConstantComposite %v4float %float_12 %float_13 %float_14 %float_15
%m = OpConstantComposite %mat4v4float %8 %13 %18 %23
%void = OpTypeVoid
%25 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %25
%28 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1 @@
let m = mat4x4<f32>(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0);