tint: Implement DP4a on SPIR-V writer

Bug: tint:1497
Test: tint_unittests
Change-Id: Id0aa2cedb5de1a2f3139b1f67c320ac78f93aa57
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91500
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Jiawei Shao 2022-05-26 00:25:04 +00:00 committed by Dawn LUCI CQ
parent 8ae9e94344
commit ce6adf4c67
4 changed files with 123 additions and 6 deletions

View File

@ -368,7 +368,11 @@ void Builder::push_capability(uint32_t cap) {
} }
} }
bool Builder::GenerateExtension(ast::Extension) { void Builder::push_extension(const char* extension) {
extensions_.push_back(Instruction{spv::Op::OpExtension, {Operand(extension)}});
}
bool Builder::GenerateExtension(ast::Extension extension) {
/* /*
For each supported extension, push corresponding capability into the builder. For each supported extension, push corresponding capability into the builder.
For example: For example:
@ -379,6 +383,15 @@ bool Builder::GenerateExtension(ast::Extension) {
push_capability(SpvCapabilityStorageInputOutput16); push_capability(SpvCapabilityStorageInputOutput16);
} }
*/ */
switch (extension) {
case ast::Extension::kChromiumExperimentalDP4a:
push_extension("SPV_KHR_integer_dot_product");
push_capability(SpvCapabilityDotProductKHR);
push_capability(SpvCapabilityDotProductInput4x8BitPackedKHR);
break;
default:
return false;
}
return true; return true;
} }
@ -2494,6 +2507,30 @@ uint32_t Builder::GenerateBuiltinCall(const sem::Call* call, const sem::Builtin*
glsl_std450(GLSLstd450SAbs); glsl_std450(GLSLstd450SAbs);
} }
break; break;
case BuiltinType::kDot4I8Packed: {
auto first_param_id = get_arg_as_value_id(0);
auto second_param_id = get_arg_as_value_id(1);
if (!push_function_inst(spv::Op::OpSDotKHR,
{Operand(result_type_id), result, Operand(first_param_id),
Operand(second_param_id),
Operand(static_cast<uint32_t>(
spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
return 0;
}
return result_id;
}
case BuiltinType::kDot4U8Packed: {
auto first_param_id = get_arg_as_value_id(0);
auto second_param_id = get_arg_as_value_id(1);
if (!push_function_inst(spv::Op::OpUDotKHR,
{Operand(result_type_id), result, Operand(first_param_id),
Operand(second_param_id),
Operand(static_cast<uint32_t>(
spv::PackedVectorFormat::PackedVectorFormat4x8BitKHR))})) {
return 0;
}
return result_id;
}
default: { default: {
auto inst_id = builtin_to_glsl_method(builtin); auto inst_id = builtin_to_glsl_method(builtin);
if (inst_id == 0) { if (inst_id == 0) {

View File

@ -113,11 +113,8 @@ class Builder {
/// @returns the capabilities /// @returns the capabilities
const InstructionList& capabilities() const { return capabilities_; } const InstructionList& capabilities() const { return capabilities_; }
/// Adds an instruction to the extensions /// Adds an instruction to the extensions
/// @param op the op to set /// @param extension the name of the extension
/// @param operands the operands for the instruction void push_extension(const char* extension);
void push_extension(spv::Op op, const OperandList& operands) {
extensions_.push_back(Instruction{op, operands});
}
/// @returns the extensions /// @returns the extensions
const InstructionList& extensions() const { return extensions_; } const InstructionList& extensions() const { return extensions_; }
/// Adds an instruction to the ext import /// Adds an instruction to the ext import

View File

@ -2601,5 +2601,80 @@ OpFunctionEnd
)"); )");
} }
TEST_F(BuiltinBuilderTest, Call_Dot4I8Packed) {
auto* ext =
create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
ast::Extension::kChromiumExperimentalDP4a);
AST().AddEnable(ext);
auto* val1 = Var("val1", ty.u32());
auto* val2 = Var("val2", ty.u32());
auto* call = Call("dot4I8Packed", val1, val2);
auto* func = WrapInFunction(val1, val2, call);
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
OpName %5 "val1"
OpName %9 "val2"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeInt 32 0
%6 = OpTypePointer Function %7
%8 = OpConstantNull %7
%11 = OpTypeInt 32 1
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpVariable %6 Function %8
%9 = OpVariable %6 Function %8
%12 = OpLoad %7 %5
%13 = OpLoad %7 %9
%10 = OpSDot %11 %12 %13 PackedVectorFormat4x8Bit
OpReturn
OpFunctionEnd
)");
}
TEST_F(BuiltinBuilderTest, Call_Dot4U8Packed) {
auto* ext =
create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
ast::Extension::kChromiumExperimentalDP4a);
AST().AddEnable(ext);
auto* val1 = Var("val1", ty.u32());
auto* val2 = Var("val2", ty.u32());
auto* call = Call("dot4U8Packed", val1, val2);
auto* func = WrapInFunction(val1, val2, call);
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
OpName %5 "val1"
OpName %9 "val2"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeInt 32 0
%6 = OpTypePointer Function %7
%8 = OpConstantNull %7
%3 = OpFunction %2 None %1
%4 = OpLabel
%5 = OpVariable %6 Function %8
%9 = OpVariable %6 Function %8
%11 = OpLoad %7 %5
%12 = OpLoad %7 %9
%10 = OpUDot %7 %11 %12 PackedVectorFormat4x8Bit
OpReturn
OpFunctionEnd
)");
}
} // namespace } // namespace
} // namespace tint::writer::spirv } // namespace tint::writer::spirv

View File

@ -49,5 +49,13 @@ TEST_F(BuilderTest, Capabilities_Dedup) {
EXPECT_EQ(DumpInstructions(b.capabilities()), "OpCapability Shader\n"); EXPECT_EQ(DumpInstructions(b.capabilities()), "OpCapability Shader\n");
} }
TEST_F(BuilderTest, DeclareExtension) {
spirv::Builder& b = Build();
b.push_extension("SPV_KHR_integer_dot_product");
EXPECT_EQ(DumpInstructions(b.extensions()), "OpExtension \"SPV_KHR_integer_dot_product\"\n");
}
} // namespace } // namespace
} // namespace tint::writer::spirv } // namespace tint::writer::spirv