diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index df57345b8f..e2d1fac831 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -368,7 +368,11 @@ void Builder::push_capability(uint32_t cap) { } } -bool Builder::GenerateExtension(ast::Extension) { +void Builder::push_extension(const char* extension) { + extensions_.push_back(Instruction{spv::Op::OpExtension, {Operand(extension)}}); +} + +bool Builder::GenerateExtension(ast::Extension extension) { /* For each supported extension, push corresponding capability into the builder. For example: @@ -379,6 +383,15 @@ bool Builder::GenerateExtension(ast::Extension) { push_capability(SpvCapabilityStorageInputOutput16); } */ + switch (extension) { + case ast::Extension::kChromiumExperimentalDP4a: + push_extension("SPV_KHR_integer_dot_product"); + push_capability(SpvCapabilityDotProductKHR); + push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR); + break; + default: + return false; + } return true; } @@ -2494,6 +2507,30 @@ uint32_t Builder::GenerateBuiltinCall(const sem::Call* call, const sem::Builtin* glsl_std450(GLSLstd450SAbs); } break; + case BuiltinType::kDot4I8Packed: { + auto first_param_id = get_arg_as_value_id(0); + auto second_param_id = get_arg_as_value_id(1); + if (!push_function_inst(spv::Op::OpSDotKHR, + {Operand(result_type_id), result, Operand(first_param_id), + Operand(second_param_id), + Operand(static_cast( + spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) { + return 0; + } + return result_id; + } + case BuiltinType::kDot4U8Packed: { + auto first_param_id = get_arg_as_value_id(0); + auto second_param_id = get_arg_as_value_id(1); + if (!push_function_inst(spv::Op::OpUDotKHR, + {Operand(result_type_id), result, Operand(first_param_id), + Operand(second_param_id), + Operand(static_cast( + spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) { + return 0; + } + return result_id; + } default: { auto inst_id = builtin_to_glsl_method(builtin); if (inst_id == 0) { diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h index 1745ed568f..fc2fa13940 100644 --- a/src/tint/writer/spirv/builder.h +++ b/src/tint/writer/spirv/builder.h @@ -113,11 +113,8 @@ class Builder { /// @returns the capabilities const InstructionList& capabilities() const { return capabilities_; } /// Adds an instruction to the extensions - /// @param op the op to set - /// @param operands the operands for the instruction - void push_extension(spv::Op op, const OperandList& operands) { - extensions_.push_back(Instruction{op, operands}); - } + /// @param extension the name of the extension + void push_extension(const char* extension); /// @returns the extensions const InstructionList& extensions() const { return extensions_; } /// Adds an instruction to the ext import diff --git a/src/tint/writer/spirv/builder_builtin_test.cc b/src/tint/writer/spirv/builder_builtin_test.cc index 6d1316e99c..59a567aae1 100644 --- a/src/tint/writer/spirv/builder_builtin_test.cc +++ b/src/tint/writer/spirv/builder_builtin_test.cc @@ -2601,5 +2601,80 @@ OpFunctionEnd )"); } +TEST_F(BuiltinBuilderTest, Call_Dot4I8Packed) { + auto* ext = + create(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, + ast::Extension::kChromiumExperimentalDP4a); + AST().AddEnable(ext); + + auto* val1 = Var("val1", ty.u32()); + auto* val2 = Var("val2", ty.u32()); + auto* call = Call("dot4I8Packed", val1, val2); + auto* func = WrapInFunction(val1, val2, call); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +OpName %5 "val1" +OpName %9 "val2" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%7 = OpTypeInt 32 0 +%6 = OpTypePointer Function %7 +%8 = OpConstantNull %7 +%11 = OpTypeInt 32 1 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%5 = OpVariable %6 Function %8 +%9 = OpVariable %6 Function %8 +%12 = OpLoad %7 %5 +%13 = OpLoad %7 %9 +%10 = OpSDot %11 %12 %13 PackedVectorFormat4x8Bit +OpReturn +OpFunctionEnd +)"); +} + +TEST_F(BuiltinBuilderTest, Call_Dot4U8Packed) { + auto* ext = + create(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}}, + ast::Extension::kChromiumExperimentalDP4a); + AST().AddEnable(ext); + + auto* val1 = Var("val1", ty.u32()); + auto* val2 = Var("val2", ty.u32()); + auto* call = Call("dot4U8Packed", val1, val2); + auto* func = WrapInFunction(val1, val2, call); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +OpName %5 "val1" +OpName %9 "val2" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%7 = OpTypeInt 32 0 +%6 = OpTypePointer Function %7 +%8 = OpConstantNull %7 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%5 = OpVariable %6 Function %8 +%9 = OpVariable %6 Function %8 +%11 = OpLoad %7 %5 +%12 = OpLoad %7 %9 +%10 = OpUDot %7 %11 %12 PackedVectorFormat4x8Bit +OpReturn +OpFunctionEnd +)"); +} + } // namespace } // namespace tint::writer::spirv diff --git a/src/tint/writer/spirv/builder_test.cc b/src/tint/writer/spirv/builder_test.cc index 3548f9a360..24d5b725f5 100644 --- a/src/tint/writer/spirv/builder_test.cc +++ b/src/tint/writer/spirv/builder_test.cc @@ -49,5 +49,13 @@ TEST_F(BuilderTest, Capabilities_Dedup) { EXPECT_EQ(DumpInstructions(b.capabilities()), "OpCapability Shader\n"); } +TEST_F(BuilderTest, DeclareExtension) { + spirv::Builder& b = Build(); + + b.push_extension("SPV_KHR_integer_dot_product"); + + EXPECT_EQ(DumpInstructions(b.extensions()), "OpExtension \"SPV_KHR_integer_dot_product\"\n"); +} + } // namespace } // namespace tint::writer::spirv