[spirv-writer] Add binary multiplication.

This CL adds binary multiplication generation to the SPIR-V writer.

Bug: tint:5
Change-Id: I668d24035e947c51a9737549fd0841a4e8af1331
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19700
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-16 14:14:11 +00:00
parent 460345d993
commit 366b74c364
2 changed files with 371 additions and 18 deletions

View File

@ -79,6 +79,36 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
return model; return model;
} }
bool is_float_scalar(ast::type::Type* type) {
return type->IsF32();
}
bool is_float_matrix(ast::type::Type* type) {
return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
}
bool is_float_vector(ast::type::Type* type) {
return type->IsVector() && type->AsVector()->type()->IsF32();
}
bool is_float_scalar_or_vector(ast::type::Type* type) {
return is_float_scalar(type) || is_float_vector(type);
}
bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
return type->IsU32() ||
(type->IsVector() && type->AsVector()->type()->IsU32());
}
bool is_signed_scalar_or_vector(ast::type::Type* type) {
return type->IsI32() ||
(type->IsVector() && type->AsVector()->type()->IsI32());
}
bool is_integer_scalar_or_vector(ast::type::Type* type) {
return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
}
} // namespace } // namespace
Builder::Builder() : scope_stack_({}) {} Builder::Builder() : scope_stack_({}) {}
@ -569,12 +599,9 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// Handle int and float and the vectors of those types. Other types // Handle int and float and the vectors of those types. Other types
// should have been rejected by validation. // should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type(); auto* lhs_type = expr->lhs()->result_type();
bool lhs_is_float_or_vec = auto* rhs_type = expr->rhs()->result_type();
lhs_type->IsF32() || bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
(lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32()); bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
bool lhs_is_unsigned =
lhs_type->IsU32() ||
(lhs_type->IsVector() && lhs_type->AsVector()->type()->IsU32());
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
if (expr->IsAnd()) { if (expr->IsAnd()) {
@ -631,6 +658,45 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
} else { } else {
op = spv::Op::OpSMod; op = spv::Op::OpSMod;
} }
} else if (expr->IsMultiply()) {
if (is_integer_scalar_or_vector(lhs_type)) {
// If the left hand side is an integer then this _has_ to be OpIMul as
// there there is no other integer multiplication.
op = spv::Op::OpIMul;
} else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
// Float scalars multiply with OpFMul
op = spv::Op::OpFMul;
} else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
// Float vectors must be validated to be the same size and then use OpFMul
op = spv::Op::OpFMul;
} else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
// Scalar * Vector we need to flip lhs and rhs types
// because OpVectorTimesScalar expects <vector>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
// float vector * scalar
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
// Scalar * Matrix we need to flip lhs and rhs types because
// OpMatrixTimesScalar expects <matrix>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
// float matrix * scalar
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
// float vector * matrix
op = spv::Op::OpVectorTimesMatrix;
} else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
// float matrix * vector
op = spv::Op::OpMatrixTimesVector;
} else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
// float matrix * matrix
op = spv::Op::OpMatrixTimesMatrix;
} else {
return 0;
}
} else if (expr->IsNotEqual()) { } else if (expr->IsNotEqual()) {
op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual; op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual;
} else if (expr->IsOr()) { } else if (expr->IsOr()) {

View File

@ -17,10 +17,12 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/binary_expression.h" #include "src/ast/binary_expression.h"
#include "src/ast/float_literal.h" #include "src/ast/float_literal.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/int_literal.h" #include "src/ast/int_literal.h"
#include "src/ast/scalar_constructor_expression.h" #include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h" #include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
@ -65,7 +67,7 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
%2 = OpConstant %1 3 %2 = OpConstant %1 3
%3 = OpConstant %1 4 %3 = OpConstant %1 4
@ -108,7 +110,7 @@ TEST_P(BinaryArithSignedIntegerTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -125,6 +127,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"}, BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
BinaryData{ast::BinaryOp::kDivide, "OpSDiv"}, BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpSMod"}, BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"}, BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"}, BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"}, BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@ -152,7 +155,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
%2 = OpConstant %1 3 %2 = OpConstant %1 3
%3 = OpConstant %1 4 %3 = OpConstant %1 4
@ -195,7 +198,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -212,6 +215,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"}, BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
BinaryData{ast::BinaryOp::kDivide, "OpUDiv"}, BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpUMod"}, BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"}, BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"}, BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"}, BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
@ -239,7 +243,7 @@ TEST_P(BinaryArithFloatTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 3.20000005 %2 = OpConstant %1 3.20000005
%3 = OpConstant %1 4.5 %3 = OpConstant %1 4.5
@ -283,7 +287,7 @@ TEST_P(BinaryArithFloatTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -298,6 +302,7 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"}, testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"}, BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpFMod"}, BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"})); BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryCompareUnsignedIntegerTest = testing::TestWithParam<BinaryData>; using BinaryCompareUnsignedIntegerTest = testing::TestWithParam<BinaryData>;
@ -320,7 +325,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
%2 = OpConstant %1 3 %2 = OpConstant %1 3
%3 = OpConstant %1 4 %3 = OpConstant %1 4
@ -365,7 +370,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -407,7 +412,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
%2 = OpConstant %1 3 %2 = OpConstant %1 3
%3 = OpConstant %1 4 %3 = OpConstant %1 4
@ -452,7 +457,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -494,7 +499,7 @@ TEST_P(BinaryCompareFloatTest, Scalar) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 3.20000005 %2 = OpConstant %1 3.20000005
%3 = OpConstant %1 4.5 %3 = OpConstant %1 4.5
@ -539,7 +544,7 @@ TEST_P(BinaryCompareFloatTest, Vector) {
Builder b; Builder b;
b.push_function(Function{}); b.push_function(Function{});
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3 %1 = OpTypeVector %2 3
%3 = OpConstant %2 1 %3 = OpConstant %2 1
@ -561,6 +566,288 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"}, BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"})); BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::ExpressionList vals;
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto lhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
Context ctx;
TypeDeterminer td(&ctx);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
%1 = OpTypeVector %2 3
%3 = OpConstant %2 1
%4 = OpConstantComposite %1 %3 %3 %3
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
"%5 = OpVectorTimesScalar %1 %4 %3\n");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
ast::ExpressionList vals;
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto rhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
Context ctx;
TypeDeterminer td(&ctx);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
%2 = OpConstant %1 1
%3 = OpTypeVector %1 3
%4 = OpConstantComposite %3 %2 %2 %2
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
"%5 = OpVectorTimesScalar %3 %4 %2\n");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
ast::type::F32Type f32;
ast::type::MatrixType mat3(&f32, 3, 3);
auto var = std::make_unique<ast::Variable>(
"mat", ast::StorageClass::kFunction, &mat3);
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 1
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %6 %7
)");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
ast::type::F32Type f32;
ast::type::MatrixType mat3(&f32, 3, 3);
auto var = std::make_unique<ast::Variable>(
"mat", ast::StorageClass::kFunction, &mat3);
auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 1
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%7 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %7 %6
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto var = std::make_unique<ast::Variable>(
"mat", ast::StorageClass::kFunction, &mat3);
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
ast::ExpressionList vals;
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto rhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 1
%8 = OpConstantComposite %4 %7 %7 %7
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpLoad %3 %1
%9 = OpMatrixTimesVector %4 %6 %8
)");
}
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto var = std::make_unique<ast::Variable>(
"mat", ast::StorageClass::kFunction, &mat3);
ast::ExpressionList vals;
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
auto lhs =
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 1
%7 = OpConstantComposite %4 %6 %6 %6
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%8 = OpLoad %3 %1
%9 = OpVectorTimesMatrix %4 %7 %8
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto var = std::make_unique<ast::Variable>(
"mat", ast::StorageClass::kFunction, &mat3);
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
Context ctx;
TypeDeterminer td(&ctx);
td.RegisterVariableForTesting(var.get());
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
std::move(rhs));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
Builder b;
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%6 = OpLoad %3 %1
%7 = OpLoad %3 %1
%8 = OpMatrixTimesMatrix %3 %6 %7
)");
}
} // namespace } // namespace
} // namespace spirv } // namespace spirv
} // namespace writer } // namespace writer