tint: Implement DP4a on HLSL writer

Bug: tint:1497
Test: tint_unittests
Change-Id: I29cc3e56949071230cdbd5afdc59eef076777149
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89706
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Jiawei Shao 2022-05-13 00:09:56 +00:00 committed by Dawn LUCI CQ
parent 189325ab1f
commit ab9757036b
5 changed files with 115 additions and 8 deletions

View File

@ -83,6 +83,10 @@ bool IsAtomicBuiltin(BuiltinType i) {
i == sem::BuiltinType::kAtomicCompareExchangeWeak; i == sem::BuiltinType::kAtomicCompareExchangeWeak;
} }
bool IsDP4aBuiltin(BuiltinType i) {
return i == sem::BuiltinType::kDot4I8Packed || i == sem::BuiltinType::kDot4U8Packed;
}
Builtin::Builtin(BuiltinType type, Builtin::Builtin(BuiltinType type,
const sem::Type* return_type, const sem::Type* return_type,
std::vector<Parameter*> parameters, std::vector<Parameter*> parameters,
@ -135,6 +139,10 @@ bool Builtin::IsAtomic() const {
return IsAtomicBuiltin(type_); return IsAtomicBuiltin(type_);
} }
bool Builtin::IsDP4a() const {
return IsDP4aBuiltin(type_);
}
bool Builtin::HasSideEffects() const { bool Builtin::HasSideEffects() const {
if (IsAtomic() && type_ != sem::BuiltinType::kAtomicLoad) { if (IsAtomic() && type_ != sem::BuiltinType::kAtomicLoad) {
return true; return true;
@ -146,13 +154,10 @@ bool Builtin::HasSideEffects() const {
} }
ast::Enable::ExtensionKind Builtin::RequiredExtension() const { ast::Enable::ExtensionKind Builtin::RequiredExtension() const {
switch (type_) { if (IsDP4a()) {
case sem::BuiltinType::kDot4I8Packed: return ast::Enable::ExtensionKind::kChromiumExperimentalDP4a;
case sem::BuiltinType::kDot4U8Packed:
return ast::Enable::ExtensionKind::kChromiumExperimentalDP4a;
default:
return ast::Enable::ExtensionKind::kNotAnExtension;
} }
return ast::Enable::ExtensionKind::kNotAnExtension;
} }
} // namespace tint::sem } // namespace tint::sem

View File

@ -70,6 +70,11 @@ bool IsBarrierBuiltin(BuiltinType i);
/// @returns true if the given `i` is a atomic builtin /// @returns true if the given `i` is a atomic builtin
bool IsAtomicBuiltin(BuiltinType i); bool IsAtomicBuiltin(BuiltinType i);
/// Determins if the given `i` is a DP4a builtin
/// @param i the builtin
/// @returns true if the given `i` is a DP4a builtin
bool IsDP4aBuiltin(BuiltinType i);
/// Builtin holds the semantic information for a builtin function. /// Builtin holds the semantic information for a builtin function.
class Builtin final : public Castable<Builtin, CallTarget> { class Builtin final : public Castable<Builtin, CallTarget> {
public: public:
@ -130,6 +135,10 @@ class Builtin final : public Castable<Builtin, CallTarget> {
/// @returns true if builtin is a atomic builtin /// @returns true if builtin is a atomic builtin
bool IsAtomic() const; bool IsAtomic() const;
/// @returns true if builtin is a DP4a builtin (defined in the extension
/// chromium_experimental_DP4a)
bool IsDP4a() const;
/// @returns true if intrinsic may have side-effects (i.e. writes to at least /// @returns true if intrinsic may have side-effects (i.e. writes to at least
/// one of its inputs) /// one of its inputs)
bool HasSideEffects() const; bool HasSideEffects() const;

View File

@ -1027,6 +1027,9 @@ bool GeneratorImpl::EmitBuiltinCall(std::ostream& out,
if (builtin->IsAtomic()) { if (builtin->IsAtomic()) {
return EmitWorkgroupAtomicCall(out, expr, builtin); return EmitWorkgroupAtomicCall(out, expr, builtin);
} }
if (builtin->IsDP4a()) {
return EmitDP4aCall(out, expr, builtin);
}
auto name = generate_builtin_name(builtin); auto name = generate_builtin_name(builtin);
if (name.empty()) { if (name.empty()) {
return false; return false;
@ -2033,6 +2036,32 @@ bool GeneratorImpl::EmitDataUnpackingCall(std::ostream& out,
}); });
} }
bool GeneratorImpl::EmitDP4aCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin) {
// TODO(crbug.com/tint/1497): support the polyfill version of DP4a functions.
return CallBuiltinHelper(
out, expr, builtin, [&](TextBuffer* b, const std::vector<std::string>& params) {
std::string functionName;
switch (builtin->Type()) {
case sem::BuiltinType::kDot4I8Packed:
functionName = "dot4add_i8packed";
break;
case sem::BuiltinType::kDot4U8Packed:
functionName = "dot4add_u8packed";
break;
default:
diagnostics_.add_error(diag::System::Writer,
"Internal error: unhandled DP4a builtin");
return false;
}
line(b) << "return " << functionName << "(" << params[0] << ", " << params[1]
<< ", 0);";
return true;
});
}
bool GeneratorImpl::EmitBarrierCall(std::ostream& out, const sem::Builtin* builtin) { bool GeneratorImpl::EmitBarrierCall(std::ostream& out, const sem::Builtin* builtin) {
// TODO(crbug.com/tint/661): Combine sequential barriers to a single // TODO(crbug.com/tint/661): Combine sequential barriers to a single
// instruction. // instruction.

View File

@ -242,7 +242,7 @@ class GeneratorImpl : public TextGenerator {
/// Handles generating a call to data packing builtin /// Handles generating a call to data packing builtin
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the call expression /// @param expr the call expression
/// @param builtin the semantic information for the texture builtin /// @param builtin the semantic information for the builtin
/// @returns true if the call expression is emitted /// @returns true if the call expression is emitted
bool EmitDataPackingCall(std::ostream& out, bool EmitDataPackingCall(std::ostream& out,
const ast::CallExpression* expr, const ast::CallExpression* expr,
@ -250,11 +250,19 @@ class GeneratorImpl : public TextGenerator {
/// Handles generating a call to data unpacking builtin /// Handles generating a call to data unpacking builtin
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param expr the call expression /// @param expr the call expression
/// @param builtin the semantic information for the texture builtin /// @param builtin the semantic information for the builtin
/// @returns true if the call expression is emitted /// @returns true if the call expression is emitted
bool EmitDataUnpackingCall(std::ostream& out, bool EmitDataUnpackingCall(std::ostream& out,
const ast::CallExpression* expr, const ast::CallExpression* expr,
const sem::Builtin* builtin); const sem::Builtin* builtin);
/// Handles generating a call to DP4a builtins (dot4I8Packed and dot4U8Packed)
/// @param out the output of the expression stream
/// @param expr the call expression
/// @param builtin the semantic information for the builtin
/// @returns true if the call expression is emitted
bool EmitDP4aCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Builtin* builtin);
/// Handles a case statement /// Handles a case statement
/// @param s the switch statement /// @param s the switch statement
/// @param case_idx the index of the switch case in the switch statement /// @param case_idx the index of the switch case in the switch statement

View File

@ -726,5 +726,61 @@ void main() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Dot4I8Packed) {
auto* ext =
create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
"chromium_experimental_dp4a");
AST().AddEnable(ext);
auto* val1 = Var("val1", ty.u32());
auto* val2 = Var("val2", ty.u32());
auto* call = Call("dot4I8Packed", val1, val2);
WrapInFunction(val1, val2, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(int tint_dot4I8Packed(uint param_0, uint param_1) {
return dot4add_i8packed(param_0, param_1, 0);
}
[numthreads(1, 1, 1)]
void test_function() {
uint val1 = 0u;
uint val2 = 0u;
const int tint_symbol = tint_dot4I8Packed(val1, val2);
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Dot4U8Packed) {
auto* ext =
create<ast::Enable>(Source{Source::Range{Source::Location{10, 2}, Source::Location{10, 5}}},
"chromium_experimental_dp4a");
AST().AddEnable(ext);
auto* val1 = Var("val1", ty.u32());
auto* val2 = Var("val2", ty.u32());
auto* call = Call("dot4U8Packed", val1, val2);
WrapInFunction(val1, val2, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(uint tint_dot4U8Packed(uint param_0, uint param_1) {
return dot4add_u8packed(param_0, param_1, 0);
}
[numthreads(1, 1, 1)]
void test_function() {
uint val1 = 0u;
uint val2 = 0u;
const uint tint_symbol = tint_dot4U8Packed(val1, val2);
return;
}
)");
}
} // namespace } // namespace
} // namespace tint::writer::hlsl } // namespace tint::writer::hlsl