[spirv-writer] Add binary multiplication.
This CL adds binary multiplication generation to the SPIR-V writer. Bug: tint:5 Change-Id: I668d24035e947c51a9737549fd0841a4e8af1331 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19700 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
460345d993
commit
366b74c364
|
@ -79,6 +79,36 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
|
|||
return model;
|
||||
}
|
||||
|
||||
bool is_float_scalar(ast::type::Type* type) {
|
||||
return type->IsF32();
|
||||
}
|
||||
|
||||
bool is_float_matrix(ast::type::Type* type) {
|
||||
return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
|
||||
}
|
||||
|
||||
bool is_float_vector(ast::type::Type* type) {
|
||||
return type->IsVector() && type->AsVector()->type()->IsF32();
|
||||
}
|
||||
|
||||
bool is_float_scalar_or_vector(ast::type::Type* type) {
|
||||
return is_float_scalar(type) || is_float_vector(type);
|
||||
}
|
||||
|
||||
bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
|
||||
return type->IsU32() ||
|
||||
(type->IsVector() && type->AsVector()->type()->IsU32());
|
||||
}
|
||||
|
||||
bool is_signed_scalar_or_vector(ast::type::Type* type) {
|
||||
return type->IsI32() ||
|
||||
(type->IsVector() && type->AsVector()->type()->IsI32());
|
||||
}
|
||||
|
||||
bool is_integer_scalar_or_vector(ast::type::Type* type) {
|
||||
return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Builder::Builder() : scope_stack_({}) {}
|
||||
|
@ -569,12 +599,9 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
|||
// Handle int and float and the vectors of those types. Other types
|
||||
// should have been rejected by validation.
|
||||
auto* lhs_type = expr->lhs()->result_type();
|
||||
bool lhs_is_float_or_vec =
|
||||
lhs_type->IsF32() ||
|
||||
(lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32());
|
||||
bool lhs_is_unsigned =
|
||||
lhs_type->IsU32() ||
|
||||
(lhs_type->IsVector() && lhs_type->AsVector()->type()->IsU32());
|
||||
auto* rhs_type = expr->rhs()->result_type();
|
||||
bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
|
||||
bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
|
||||
|
||||
spv::Op op = spv::Op::OpNop;
|
||||
if (expr->IsAnd()) {
|
||||
|
@ -631,6 +658,45 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
|||
} else {
|
||||
op = spv::Op::OpSMod;
|
||||
}
|
||||
} else if (expr->IsMultiply()) {
|
||||
if (is_integer_scalar_or_vector(lhs_type)) {
|
||||
// If the left hand side is an integer then this _has_ to be OpIMul as
|
||||
// there there is no other integer multiplication.
|
||||
op = spv::Op::OpIMul;
|
||||
} else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
|
||||
// Float scalars multiply with OpFMul
|
||||
op = spv::Op::OpFMul;
|
||||
} else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
|
||||
// Float vectors must be validated to be the same size and then use OpFMul
|
||||
op = spv::Op::OpFMul;
|
||||
} else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
|
||||
// Scalar * Vector we need to flip lhs and rhs types
|
||||
// because OpVectorTimesScalar expects <vector>, <scalar>
|
||||
std::swap(lhs_id, rhs_id);
|
||||
op = spv::Op::OpVectorTimesScalar;
|
||||
} else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
|
||||
// float vector * scalar
|
||||
op = spv::Op::OpVectorTimesScalar;
|
||||
} else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
|
||||
// Scalar * Matrix we need to flip lhs and rhs types because
|
||||
// OpMatrixTimesScalar expects <matrix>, <scalar>
|
||||
std::swap(lhs_id, rhs_id);
|
||||
op = spv::Op::OpMatrixTimesScalar;
|
||||
} else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
|
||||
// float matrix * scalar
|
||||
op = spv::Op::OpMatrixTimesScalar;
|
||||
} else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
|
||||
// float vector * matrix
|
||||
op = spv::Op::OpVectorTimesMatrix;
|
||||
} else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
|
||||
// float matrix * vector
|
||||
op = spv::Op::OpMatrixTimesVector;
|
||||
} else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
|
||||
// float matrix * matrix
|
||||
op = spv::Op::OpMatrixTimesMatrix;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else if (expr->IsNotEqual()) {
|
||||
op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual;
|
||||
} else if (expr->IsOr()) {
|
||||
|
|
|
@ -17,10 +17,12 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "src/ast/binary_expression.h"
|
||||
#include "src/ast/float_literal.h"
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/int_literal.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/u32_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
|
@ -65,7 +67,7 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
|
||||
%2 = OpConstant %1 3
|
||||
%3 = OpConstant %1 4
|
||||
|
@ -108,7 +110,7 @@ TEST_P(BinaryArithSignedIntegerTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -125,6 +127,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
|
||||
BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
|
||||
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
|
||||
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
|
||||
|
@ -152,7 +155,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
|
||||
%2 = OpConstant %1 3
|
||||
%3 = OpConstant %1 4
|
||||
|
@ -195,7 +198,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -212,6 +215,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpUDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpUMod"},
|
||||
BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
|
||||
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
|
||||
BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"},
|
||||
|
@ -239,7 +243,7 @@ TEST_P(BinaryArithFloatTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
|
||||
%2 = OpConstant %1 3.20000005
|
||||
%3 = OpConstant %1 4.5
|
||||
|
@ -283,7 +287,7 @@ TEST_P(BinaryArithFloatTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -298,6 +302,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
|
||||
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
|
||||
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||
|
||||
using BinaryCompareUnsignedIntegerTest = testing::TestWithParam<BinaryData>;
|
||||
|
@ -320,7 +325,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0
|
||||
%2 = OpConstant %1 3
|
||||
%3 = OpConstant %1 4
|
||||
|
@ -365,7 +370,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -407,7 +412,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1
|
||||
%2 = OpConstant %1 3
|
||||
%3 = OpConstant %1 4
|
||||
|
@ -452,7 +457,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -494,7 +499,7 @@ TEST_P(BinaryCompareFloatTest, Scalar) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
|
||||
%2 = OpConstant %1 3.20000005
|
||||
%3 = OpConstant %1 4.5
|
||||
|
@ -539,7 +544,7 @@ TEST_P(BinaryCompareFloatTest, Vector) {
|
|||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
|
@ -561,6 +566,288 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
|
||||
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto lhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 1
|
||||
%4 = OpConstantComposite %1 %3 %3 %3
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = OpVectorTimesScalar %1 %4 %3\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto rhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32
|
||||
%2 = OpConstant %1 1
|
||||
%3 = OpTypeVector %1 3
|
||||
%4 = OpConstantComposite %3 %2 %2 %2
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = OpVectorTimesScalar %3 %4 %2\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
auto rhs = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
td.RegisterVariableForTesting(var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%7 = OpConstant %5 1
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesScalar %3 %6 %7
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto lhs = std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f));
|
||||
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
td.RegisterVariableForTesting(var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%6 = OpConstant %5 1
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%7 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesScalar %3 %7 %6
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto rhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
td.RegisterVariableForTesting(var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%7 = OpConstant %5 1
|
||||
%8 = OpConstantComposite %4 %7 %7 %7
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%9 = OpMatrixTimesVector %4 %6 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"mat", ast::StorageClass::kFunction, &mat3);
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||
std::make_unique<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto lhs =
|
||||
std::make_unique<ast::TypeConstructorExpression>(&vec3, std::move(vals));
|
||||
|
||||
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
td.RegisterVariableForTesting(var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%6 = OpConstant %5 1
|
||||
%7 = OpConstantComposite %4 %6 %6 %6
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%8 = OpLoad %3 %1
|
||||
%9 = OpVectorTimesMatrix %4 %7 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>(
|
||||
"mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto lhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
auto rhs = std::make_unique<ast::IdentifierExpression>("mat");
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
td.RegisterVariableForTesting(var.get());
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs),
|
||||
std::move(rhs));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
|
||||
Builder b;
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%7 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesMatrix %3 %6 %7
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
|
Loading…
Reference in New Issue