diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index f2b220ce46..382681b1ed 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1284,6 +1284,8 @@ uint32_t Builder::GenerateIntrinsic(const std::string& name, op = spv::Op::OpIsInf; } else if (name == "is_nan") { op = spv::Op::OpIsNan; + } else if (name == "outer_product") { + op = spv::Op::OpOuterProduct; } if (op == spv::Op::OpNop) { error_ = "unable to determine operator for: " + name; diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index 34d434e8ea..ea4974c0e4 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -19,6 +19,7 @@ #include "src/ast/identifier_expression.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/matrix_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/variable.h" #include "src/context.h" @@ -304,6 +305,57 @@ INSTANTIATE_TEST_SUITE_P( IntrinsicData{"fwidth_fine", "OpFwidthFine"}, IntrinsicData{"fwidth_coarse", "OpFwidthCoarse"})); +TEST_F(BuilderTest, Call_OuterProduct) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec3(&f32, 3); + ast::type::MatrixType mat(&f32, 2, 3); + + auto v2 = + std::make_unique("v2", ast::StorageClass::kPrivate, &vec2); + auto v3 = + std::make_unique("v3", ast::StorageClass::kPrivate, &vec3); + + ast::ExpressionList params; + params.push_back(std::make_unique("v2")); + params.push_back(std::make_unique("v3")); + ast::CallExpression expr( + std::make_unique("outer_product"), + std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(v2.get()); + td.RegisterVariableForTesting(v3.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(v2.get())) << b.error(); + ASSERT_TRUE(b.GenerateGlobalVariable(v3.get())) << b.error(); + + EXPECT_EQ(b.GenerateCallExpression(&expr), 10u) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 2 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeVector %4 3 +%7 = OpTypePointer Private %8 +%9 = OpConstantNull %8 +%6 = OpVariable %7 Private %9 +%11 = OpTypeMatrix %3 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%12 = OpLoad %3 %1 +%13 = OpLoad %8 %6 +%10 = OpOuterProduct %11 %12 %13 +)"); +} + } // namespace } // namespace spirv } // namespace writer