diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index c388e70f40..e9e45d7c04 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -79,6 +79,36 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) { return model; } +bool is_float_scalar(ast::type::Type* type) { + return type->IsF32(); +} + +bool is_float_matrix(ast::type::Type* type) { + return type->IsMatrix() && type->AsMatrix()->type()->IsF32(); +} + +bool is_float_vector(ast::type::Type* type) { + return type->IsVector() && type->AsVector()->type()->IsF32(); +} + +bool is_float_scalar_or_vector(ast::type::Type* type) { + return is_float_scalar(type) || is_float_vector(type); +} + +bool is_unsigned_scalar_or_vector(ast::type::Type* type) { + return type->IsU32() || + (type->IsVector() && type->AsVector()->type()->IsU32()); +} + +bool is_signed_scalar_or_vector(ast::type::Type* type) { + return type->IsI32() || + (type->IsVector() && type->AsVector()->type()->IsI32()); +} + +bool is_integer_scalar_or_vector(ast::type::Type* type) { + return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type); +} + } // namespace Builder::Builder() : scope_stack_({}) {} @@ -569,12 +599,9 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // Handle int and float and the vectors of those types. Other types // should have been rejected by validation. auto* lhs_type = expr->lhs()->result_type(); - bool lhs_is_float_or_vec = - lhs_type->IsF32() || - (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32()); - bool lhs_is_unsigned = - lhs_type->IsU32() || - (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsU32()); + auto* rhs_type = expr->rhs()->result_type(); + bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type); + bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type); spv::Op op = spv::Op::OpNop; if (expr->IsAnd()) { @@ -631,6 +658,45 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { } else { op = spv::Op::OpSMod; } + } else if (expr->IsMultiply()) { + if (is_integer_scalar_or_vector(lhs_type)) { + // If the left hand side is an integer then this _has_ to be OpIMul as + // there there is no other integer multiplication. + op = spv::Op::OpIMul; + } else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) { + // Float scalars multiply with OpFMul + op = spv::Op::OpFMul; + } else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) { + // Float vectors must be validated to be the same size and then use OpFMul + op = spv::Op::OpFMul; + } else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) { + // Scalar * Vector we need to flip lhs and rhs types + // because OpVectorTimesScalar expects , + std::swap(lhs_id, rhs_id); + op = spv::Op::OpVectorTimesScalar; + } else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) { + // float vector * scalar + op = spv::Op::OpVectorTimesScalar; + } else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) { + // Scalar * Matrix we need to flip lhs and rhs types because + // OpMatrixTimesScalar expects , + std::swap(lhs_id, rhs_id); + op = spv::Op::OpMatrixTimesScalar; + } else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) { + // float matrix * scalar + op = spv::Op::OpMatrixTimesScalar; + } else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) { + // float vector * matrix + op = spv::Op::OpVectorTimesMatrix; + } else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) { + // float matrix * vector + op = spv::Op::OpMatrixTimesVector; + } else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) { + // float matrix * matrix + op = spv::Op::OpMatrixTimesMatrix; + } else { + return 0; + } } else if (expr->IsNotEqual()) { op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual; } else if (expr->IsOr()) { diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index 97b3569772..fe18196c13 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -17,10 +17,12 @@ #include "gtest/gtest.h" #include "src/ast/binary_expression.h" #include "src/ast/float_literal.h" +#include "src/ast/identifier_expression.h" #include "src/ast/int_literal.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/matrix_type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -65,7 +67,7 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 %2 = OpConstant %1 3 %3 = OpConstant %1 4 @@ -108,7 +110,7 @@ TEST_P(BinaryArithSignedIntegerTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -125,6 +127,7 @@ INSTANTIATE_TEST_SUITE_P( BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"}, BinaryData{ast::BinaryOp::kDivide, "OpSDiv"}, BinaryData{ast::BinaryOp::kModulo, "OpSMod"}, + BinaryData{ast::BinaryOp::kMultiply, "OpIMul"}, BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"}, BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"}, BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"}, @@ -152,7 +155,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 %2 = OpConstant %1 3 %3 = OpConstant %1 4 @@ -195,7 +198,7 @@ TEST_P(BinaryArithUnsignedIntegerTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -212,6 +215,7 @@ INSTANTIATE_TEST_SUITE_P( BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"}, BinaryData{ast::BinaryOp::kDivide, "OpUDiv"}, BinaryData{ast::BinaryOp::kModulo, "OpUMod"}, + BinaryData{ast::BinaryOp::kMultiply, "OpIMul"}, BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"}, BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"}, BinaryData{ast::BinaryOp::kShiftRight, "OpShiftRightLogical"}, @@ -239,7 +243,7 @@ TEST_P(BinaryArithFloatTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 %2 = OpConstant %1 3.20000005 %3 = OpConstant %1 4.5 @@ -283,7 +287,7 @@ TEST_P(BinaryArithFloatTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -298,6 +302,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"}, BinaryData{ast::BinaryOp::kDivide, "OpFDiv"}, BinaryData{ast::BinaryOp::kModulo, "OpFMod"}, + BinaryData{ast::BinaryOp::kMultiply, "OpFMul"}, BinaryData{ast::BinaryOp::kSubtract, "OpFSub"})); using BinaryCompareUnsignedIntegerTest = testing::TestWithParam; @@ -320,7 +325,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 0 %2 = OpConstant %1 3 %3 = OpConstant %1 4 @@ -365,7 +370,7 @@ TEST_P(BinaryCompareUnsignedIntegerTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 0 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -407,7 +412,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeInt 32 1 %2 = OpConstant %1 3 %3 = OpConstant %1 4 @@ -452,7 +457,7 @@ TEST_P(BinaryCompareSignedIntegerTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -494,7 +499,7 @@ TEST_P(BinaryCompareFloatTest, Scalar) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 4) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 %2 = OpConstant %1 3.20000005 %3 = OpConstant %1 4.5 @@ -539,7 +544,7 @@ TEST_P(BinaryCompareFloatTest, Vector) { Builder b; b.push_function(Function{}); - ASSERT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 %1 = OpTypeVector %2 3 %3 = OpConstant %2 1 @@ -561,6 +566,288 @@ INSTANTIATE_TEST_SUITE_P( BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"}, BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"})); +TEST_F(BuilderTest, Binary_Multiply_VectorScalar) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + auto lhs = + std::make_unique(&vec3, std::move(vals)); + + auto rhs = std::make_unique( + std::make_unique(&f32, 1.f)); + + Context ctx; + TypeDeterminer td(&ctx); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 32 +%1 = OpTypeVector %2 3 +%3 = OpConstant %2 1 +%4 = OpConstantComposite %1 %3 %3 %3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + "%5 = OpVectorTimesScalar %1 %4 %3\n"); +} + +TEST_F(BuilderTest, Binary_Multiply_ScalarVector) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto lhs = std::make_unique( + std::make_unique(&f32, 1.f)); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + auto rhs = + std::make_unique(&vec3, std::move(vals)); + + Context ctx; + TypeDeterminer td(&ctx); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 5) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 32 +%2 = OpConstant %1 1 +%3 = OpTypeVector %1 3 +%4 = OpConstantComposite %3 %2 %2 %2 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + "%5 = OpVectorTimesScalar %3 %4 %2\n"); +} + +TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) { + ast::type::F32Type f32; + ast::type::MatrixType mat3(&f32, 3, 3); + + auto var = std::make_unique( + "mat", ast::StorageClass::kFunction, &mat3); + auto lhs = std::make_unique("mat"); + auto rhs = std::make_unique( + std::make_unique(&f32, 1.f)); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%3 = OpTypeMatrix %4 3 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +%7 = OpConstant %5 1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpLoad %3 %1 +%8 = OpMatrixTimesScalar %3 %6 %7 +)"); +} + +TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) { + ast::type::F32Type f32; + ast::type::MatrixType mat3(&f32, 3, 3); + + auto var = std::make_unique( + "mat", ast::StorageClass::kFunction, &mat3); + auto lhs = std::make_unique( + std::make_unique(&f32, 1.f)); + auto rhs = std::make_unique("mat"); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%3 = OpTypeMatrix %4 3 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +%6 = OpConstant %5 1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%7 = OpLoad %3 %1 +%8 = OpMatrixTimesScalar %3 %7 %6 +)"); +} + +TEST_F(BuilderTest, Binary_Multiply_MatrixVector) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::MatrixType mat3(&f32, 3, 3); + + auto var = std::make_unique( + "mat", ast::StorageClass::kFunction, &mat3); + auto lhs = std::make_unique("mat"); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + auto rhs = + std::make_unique(&vec3, std::move(vals)); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%3 = OpTypeMatrix %4 3 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +%7 = OpConstant %5 1 +%8 = OpConstantComposite %4 %7 %7 %7 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpLoad %3 %1 +%9 = OpMatrixTimesVector %4 %6 %8 +)"); +} + +TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::MatrixType mat3(&f32, 3, 3); + + auto var = std::make_unique( + "mat", ast::StorageClass::kFunction, &mat3); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + auto lhs = + std::make_unique(&vec3, std::move(vals)); + + auto rhs = std::make_unique("mat"); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 9) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%3 = OpTypeMatrix %4 3 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +%6 = OpConstant %5 1 +%7 = OpConstantComposite %4 %6 %6 %6 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%8 = OpLoad %3 %1 +%9 = OpVectorTimesMatrix %4 %7 %8 +)"); +} + +TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::MatrixType mat3(&f32, 3, 3); + + auto var = std::make_unique( + "mat", ast::StorageClass::kFunction, &mat3); + auto lhs = std::make_unique("mat"); + auto rhs = std::make_unique("mat"); + + Context ctx; + TypeDeterminer td(&ctx); + td.RegisterVariableForTesting(var.get()); + + ast::BinaryExpression expr(ast::BinaryOp::kMultiply, std::move(lhs), + std::move(rhs)); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b; + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateBinaryExpression(&expr), 8) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%5 = OpTypeFloat 32 +%4 = OpTypeVector %5 3 +%3 = OpTypeMatrix %4 3 +%2 = OpTypePointer Function %3 +%1 = OpVariable %2 Function +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%6 = OpLoad %3 %1 +%7 = OpLoad %3 %1 +%8 = OpMatrixTimesMatrix %3 %6 %7 +)"); +} + } // namespace } // namespace spirv } // namespace writer