intrinsics: Implement dot() for integer vector types

Fixed: tint:1263
Change-Id: I642ea0b6c9be7f04930cf6ea1a8059825661e326
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68520
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2021-11-05 18:37:16 +00:00
committed by Tint LUCI CQ
parent a9156ff091
commit 189dc7d3fd
42 changed files with 1617 additions and 35 deletions

View File

@@ -3780,12 +3780,12 @@ constexpr ParameterInfo kParameters[] = {
{
/* [398] */
/* usage */ ParameterUsage::kNone,
/* matcher indices */ &kMatcherIndices[21],
/* matcher indices */ &kMatcherIndices[39],
},
{
/* [399] */
/* usage */ ParameterUsage::kNone,
/* matcher indices */ &kMatcherIndices[21],
/* matcher indices */ &kMatcherIndices[39],
},
{
/* [400] */
@@ -7791,12 +7791,12 @@ constexpr OverloadInfo kOverloads[] = {
{
/* [251] */
/* num parameters */ 2,
/* num open types */ 0,
/* num open types */ 1,
/* num open numbers */ 1,
/* open types */ &kOpenTypes[4],
/* open types */ &kOpenTypes[1],
/* open numbers */ &kOpenNumbers[3],
/* parameters */ &kParameters[398],
/* return matcher indices */ &kMatcherIndices[12],
/* return matcher indices */ &kMatcherIndices[1],
/* supported_stages */ PipelineStageSet(PipelineStage::kVertex, PipelineStage::kFragment, PipelineStage::kCompute),
/* is_deprecated */ false,
},
@@ -8082,7 +8082,7 @@ constexpr IntrinsicInfo kIntrinsics[] = {
},
{
/* [16] */
/* fn dot<N : num>(vec<N, f32>, vec<N, f32>) -> f32 */
/* fn dot<N : num, T : fiu32>(vec<N, T>, vec<N, T>) -> T */
/* num overloads */ 1,
/* overloads */ &kOverloads[251],
},

View File

@@ -292,7 +292,7 @@ fn cross(vec3<f32>, vec3<f32>) -> vec3<f32>
fn determinant<N: num>(mat<N, N, f32>) -> f32
fn distance(f32, f32) -> f32
fn distance<N: num>(vec<N, f32>, vec<N, f32>) -> f32
fn dot<N: num>(vec<N, f32>, vec<N, f32>) -> f32
fn dot<N: num, T: fiu32>(vec<N, T>, vec<N, T>) -> T
[[stage("fragment")]] fn dpdx(f32) -> f32
[[stage("fragment")]] fn dpdx<N: num>(vec<N, f32>) -> vec<N, f32>
[[stage("fragment")]] fn dpdxCoarse(f32) -> f32

View File

@@ -339,7 +339,7 @@ TEST_F(ResolverIntrinsicTest, Dot_Vec2) {
}
TEST_F(ResolverIntrinsicTest, Dot_Vec3) {
Global("my_var", ty.vec3<f32>(), ast::StorageClass::kPrivate);
Global("my_var", ty.vec3<i32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "my_var", "my_var");
WrapInFunction(expr);
@@ -347,11 +347,11 @@ TEST_F(ResolverIntrinsicTest, Dot_Vec3) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
EXPECT_TRUE(TypeOf(expr)->Is<sem::I32>());
}
TEST_F(ResolverIntrinsicTest, Dot_Vec4) {
Global("my_var", ty.vec4<f32>(), ast::StorageClass::kPrivate);
Global("my_var", ty.vec4<u32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "my_var", "my_var");
WrapInFunction(expr);
@@ -359,7 +359,7 @@ TEST_F(ResolverIntrinsicTest, Dot_Vec4) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
EXPECT_TRUE(TypeOf(expr)->Is<sem::U32>());
}
TEST_F(ResolverIntrinsicTest, Dot_Error_Scalar) {
@@ -372,23 +372,7 @@ TEST_F(ResolverIntrinsicTest, Dot_Error_Scalar) {
R"(error: no matching call to dot(f32, f32)
1 candidate function:
dot(vecN<f32>, vecN<f32>) -> f32
)");
}
TEST_F(ResolverIntrinsicTest, Dot_Error_VectorInt) {
Global("my_var", ty.vec4<i32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "my_var", "my_var");
WrapInFunction(expr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(error: no matching call to dot(vec4<i32>, vec4<i32>)
1 candidate function:
dot(vecN<f32>, vecN<f32>) -> f32
dot(vecN<T>, vecN<T>) -> T where: T is f32, i32 or u32
)");
}

View File

@@ -107,7 +107,11 @@ bool GeneratorImpl::Generate() {
size_t last_padding_line = 0;
line() << "#version 310 es";
line() << "precision mediump float;" << std::endl;
line() << "precision mediump float;";
auto helpers_insertion_point = current_buffer_->lines.size();
line();
for (auto* decl : builder_.AST().GlobalDeclarations()) {
if (decl->Is<ast::Alias>()) {
@@ -153,7 +157,8 @@ bool GeneratorImpl::Generate() {
}
if (!helpers_.lines.empty()) {
current_buffer_->Insert(helpers_, 0, 0);
current_buffer_->Insert("", helpers_insertion_point++, 0);
current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
}
return true;
@@ -407,6 +412,8 @@ bool GeneratorImpl::EmitCall(std::ostream& out,
return EmitTextureCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kSelect) {
return EmitSelectCall(out, expr);
} else if (intrinsic->Type() == sem::IntrinsicType::kDot) {
return EmitDotCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kModf) {
return EmitModfCall(out, expr, intrinsic);
} else if (intrinsic->Type() == sem::IntrinsicType::kFrexp) {
@@ -671,6 +678,78 @@ bool GeneratorImpl::EmitSelectCall(std::ostream& out,
return true;
}
// Emits a call to the `dot()` intrinsic into `out`.
//
// When the element type of the first parameter's vector is an integer scalar,
// the call is routed through a generated helper function (created at most once
// per vector type and cached in int_dot_funcs_). Otherwise the call is emitted
// against `dot` directly.
//
// Returns false if the helper could not be generated or if either argument
// expression fails to emit.
bool GeneratorImpl::EmitDotCall(std::ostream& out,
                                const ast::CallExpression* expr,
                                const sem::Intrinsic* intrinsic) {
  auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
  std::string fn = "dot";
  if (vec_ty->type()->is_integer_scalar()) {
    // GLSL does not have a builtin for dot() with integer vector types.
    // Generate the helper function if it hasn't been created already
    fn = utils::GetOrCreate(int_dot_funcs_, vec_ty, [&]() -> std::string {
      TextBuffer b;
      // The helper text is appended to helpers_ when this lambda exits.
      TINT_DEFER(helpers_.Append(b));
      auto fn_name = UniqueIdentifier("tint_int_dot");
      {  // (u)int tint_int_dot([i|u]vecN a, [i|u]vecN b) {
        auto l = line(&b);
        // Return type: the vector's element type.
        if (!EmitType(l, vec_ty->type(), ast::StorageClass::kNone,
                      ast::Access::kRead, "")) {
          return "";
        }
        l << " " << fn_name << "(";
        if (!EmitType(l, vec_ty, ast::StorageClass::kNone, ast::Access::kRead,
                      "")) {
          return "";
        }
        l << " a, ";
        if (!EmitType(l, vec_ty, ast::StorageClass::kNone, ast::Access::kRead,
                      "")) {
          return "";
        }
        l << " b) {";
      }
      {
        // Body: sum of the component-wise products, one term per lane.
        auto l = line(&b);
        l << " return ";
        for (uint32_t i = 0; i < vec_ty->Width(); i++) {
          if (i > 0) {
            l << " + ";
          }
          l << "a[" << i << "]*b[" << i << "]";
        }
        l << ";";
      }
      line(&b) << "}";
      return fn_name;
    });
    // An empty name signals that emitting one of the helper's types failed.
    if (fn.empty()) {
      return false;
    }
  }
  out << fn << "(";
  if (!EmitExpression(out, expr->args[0])) {
    return false;
  }
  out << ", ";
  if (!EmitExpression(out, expr->args[1])) {
    return false;
  }
  out << ")";
  return true;
}
bool GeneratorImpl::EmitModfCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) {
@@ -2216,9 +2295,6 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
}
if (auto* c = stmt->As<ast::CallStatement>()) {
auto out = line();
if (!TypeOf(c->expr)->Is<sem::Void>()) {
out << "(void) ";
}
if (!EmitCall(out, c->expr)) {
return false;
}

View File

@@ -136,6 +136,14 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the call expression
/// @returns true if the call expression is emitted
bool EmitSelectCall(std::ostream& out, const ast::CallExpression* expr);
/// Handles generating a call to the `dot()` intrinsic.
/// When the vector element type is an integer scalar, a helper function is
/// generated (at most once per vector type) and called in place of `dot`.
/// @param out the output of the expression stream
/// @param expr the call expression
/// @param intrinsic the semantic information for the intrinsic
/// @returns true if the call expression is emitted
bool EmitDotCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic);
/// Handles generating a call to the `modf()` intrinsic
/// @param out the output of the expression stream
/// @param expr the call expression
@@ -416,6 +424,7 @@ class GeneratorImpl : public TextGenerator {
std::unordered_map<const sem::Intrinsic*, std::string> intrinsics_;
std::unordered_map<const sem::Struct*, std::string> structure_builders_;
std::unordered_map<const sem::Vector*, std::string> dynamic_vector_write_;
/// Maps an integer vector type to the name of the dot() helper function
/// generated for it.
std::unordered_map<const sem::Vector*, std::string> int_dot_funcs_;
};
} // namespace glsl

View File

@@ -599,6 +599,64 @@ void main() {
}
#endif
// Checks that dot() on a vec3<i32> is emitted as a call to a generated
// `tint_int_dot` helper taking two ivec3 parameters and returning int.
TEST_F(GlslGeneratorImplTest_Intrinsic, DotI32) {
Global("v", ty.vec3<i32>(), ast::StorageClass::kPrivate);
WrapInFunction(Call("dot", "v", "v"));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
// The helper is expected after the version/precision preamble and before the
// module's own declarations.
EXPECT_EQ(gen.result(), R"(#version 310 es
precision mediump float;
int tint_int_dot(ivec3 a, ivec3 b) {
return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
}
ivec3 v = ivec3(0, 0, 0);
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void test_function() {
tint_int_dot(v, v);
return;
}
void main() {
test_function();
}
)");
}
// Checks that dot() on a vec3<u32> is emitted as a call to a generated
// `tint_int_dot` helper taking two uvec3 parameters and returning uint.
TEST_F(GlslGeneratorImplTest_Intrinsic, DotU32) {
Global("v", ty.vec3<u32>(), ast::StorageClass::kPrivate);
WrapInFunction(Call("dot", "v", "v"));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
// The helper is expected after the version/precision preamble and before the
// module's own declarations.
EXPECT_EQ(gen.result(), R"(#version 310 es
precision mediump float;
uint tint_int_dot(uvec3 a, uvec3 b) {
return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
}
uvec3 v = uvec3(0u, 0u, 0u);
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void test_function() {
tint_int_dot(v, v);
return;
}
void main() {
test_function();
}
)");
}
} // namespace
} // namespace glsl
} // namespace writer

View File

@@ -543,6 +543,8 @@ bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
auto name = generate_builtin_name(intrinsic);
switch (intrinsic->Type()) {
case sem::IntrinsicType::kDot:
return EmitDotCall(out, expr, intrinsic);
case sem::IntrinsicType::kModf:
return EmitModfCall(out, expr, intrinsic);
case sem::IntrinsicType::kFrexp:
@@ -1005,6 +1007,53 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
return true;
}
// Emits a call to the `dot()` intrinsic into `out`.
//
// Vectors with an integer scalar element type are routed through a generated
// helper that is templated on the element type and cached per vector width in
// int_dot_funcs_. All other vectors call `dot` directly.
//
// Returns false if either argument expression fails to emit.
bool GeneratorImpl::EmitDotCall(std::ostream& out,
                                const ast::CallExpression* expr,
                                const sem::Intrinsic* intrinsic) {
  auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();

  // Writes the helper for this vector width into a scratch buffer, which is
  // appended to helpers_ when the lambda exits, and returns the helper's name.
  auto build_helper = [&]() -> std::string {
    TextBuffer buf;
    TINT_DEFER(helpers_.Append(buf));

    const auto width_str = std::to_string(vec_ty->Width());
    auto helper_name = UniqueIdentifier("tint_dot" + width_str);
    auto param_ty = "vec<T," + width_str + ">";

    line(&buf) << "template<typename T>";
    line(&buf) << "T " << helper_name << "(" << param_ty << " a, " << param_ty
               << " b) {";
    {
      // Body: sum of the component-wise products, one term per lane.
      auto body = line(&buf);
      body << " return ";
      for (uint32_t lane = 0; lane < vec_ty->Width(); ++lane) {
        if (lane != 0) {
          body << " + ";
        }
        body << "a[" << lane << "]*b[" << lane << "]";
      }
      body << ";";
    }
    line(&buf) << "}";
    return helper_name;
  };

  std::string callee = "dot";
  if (vec_ty->type()->is_integer_scalar()) {
    // MSL does not have a builtin for dot() with integer vector types.
    // Generate the helper function if it hasn't been created already
    callee = utils::GetOrCreate(int_dot_funcs_, vec_ty->Width(), build_helper);
  }

  out << callee << "(";
  if (!EmitExpression(out, expr->args[0])) {
    return false;
  }
  out << ", ";
  if (!EmitExpression(out, expr->args[1])) {
    return false;
  }
  out << ")";
  return true;
}
bool GeneratorImpl::EmitModfCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) {

View File

@@ -154,6 +154,14 @@ class GeneratorImpl : public TextGenerator {
bool EmitTextureCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic);
/// Handles generating a call to the `dot()` intrinsic.
/// When the vector element type is an integer scalar, a templated helper
/// function is generated (at most once per vector width) and called in place
/// of `dot`.
/// @param out the output of the expression stream
/// @param expr the call expression
/// @param intrinsic the semantic information for the intrinsic
/// @returns true if the call expression is emitted
bool EmitDotCall(std::ostream& out,
const ast::CallExpression* expr,
const sem::Intrinsic* intrinsic);
/// Handles generating a call to the `modf()` intrinsic
/// @param out the output of the expression stream
/// @param expr the call expression
@@ -394,6 +402,7 @@ class GeneratorImpl : public TextGenerator {
std::unordered_map<const sem::Intrinsic*, std::string> intrinsics_;
std::unordered_map<const sem::Type*, std::string> unary_minus_funcs_;
/// Maps a vector width to the name of the integer dot() helper function
/// generated for that width.
std::unordered_map<uint32_t, std::string> int_dot_funcs_;
};
} // namespace msl

View File

@@ -343,6 +343,30 @@ TEST_F(MslGeneratorImplTest, Unpack2x16Float) {
EXPECT_EQ(out.str(), "float2(as_type<half2>(p1))");
}
// Checks that dot() on a vec3<i32> is emitted as a call to a generated
// `tint_dot3` helper that is templated on the element type.
TEST_F(MslGeneratorImplTest, DotI32) {
Global("v", ty.vec3<i32>(), ast::StorageClass::kPrivate);
WrapInFunction(Call("dot", "v", "v"));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
// The helper is expected directly after the standard-library preamble.
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
template<typename T>
T tint_dot3(vec<T,3> a, vec<T,3> b) {
return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
}
kernel void test_function() {
thread int3 tint_symbol = 0;
tint_dot3(tint_symbol, tint_symbol);
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Ignore) {
Func("f", {Param("a", ty.i32()), Param("b", ty.i32()), Param("c", ty.i32())},
ty.i32(), {Return(Mul(Add("a", "b"), "c"))});

View File

@@ -2423,9 +2423,48 @@ uint32_t Builder::GenerateIntrinsic(const ast::CallExpression* call,
case IntrinsicType::kCountOneBits:
op = spv::Op::OpBitCount;
break;
case IntrinsicType::kDot:
case IntrinsicType::kDot: {
op = spv::Op::OpDot;
auto* vec_ty = intrinsic->Parameters()[0]->Type()->As<sem::Vector>();
if (vec_ty->type()->is_integer_scalar()) {
// TODO(crbug.com/tint/1267): OpDot requires floating-point types, but
// WGSL also supports integer types. SPV_KHR_integer_dot_product adds
// support for integer vectors. Use it if it is available.
auto el_ty = Operand::Int(GenerateTypeIfNeeded(vec_ty->type()));
auto vec_a = Operand::Int(get_arg_as_value_id(0));
auto vec_b = Operand::Int(get_arg_as_value_id(1));
if (vec_a.to_i() == 0 || vec_b.to_i() == 0) {
return 0;
}
auto sum = Operand::Int(0);
for (uint32_t i = 0; i < vec_ty->Width(); i++) {
auto a = result_op();
auto b = result_op();
auto mul = result_op();
if (!push_function_inst(spv::Op::OpCompositeExtract,
{el_ty, a, vec_a, Operand::Int(i)}) ||
!push_function_inst(spv::Op::OpCompositeExtract,
{el_ty, b, vec_b, Operand::Int(i)}) ||
!push_function_inst(spv::Op::OpIMul, {el_ty, mul, a, b})) {
return 0;
}
if (i == 0) {
sum = mul;
} else {
auto prev_sum = sum;
auto is_last_el = i == (vec_ty->Width() - 1);
sum = is_last_el ? Operand::Int(result_id) : result_op();
if (!push_function_inst(spv::Op::OpIAdd,
{el_ty, sum, prev_sum, mul})) {
return 0;
}
}
}
return result_id;
}
break;
}
case IntrinsicType::kDpdx:
op = spv::Op::OpDPdx;
break;

View File

@@ -407,7 +407,7 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(IntrinsicData{"countOneBits", "OpBitCount"},
IntrinsicData{"reverseBits", "OpBitReverse"}));
TEST_F(IntrinsicBuilderTest, Call_Dot) {
TEST_F(IntrinsicBuilderTest, Call_Dot_F32) {
auto* var = Global("v", ty.vec3<f32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "v", "v");
@@ -432,6 +432,76 @@ TEST_F(IntrinsicBuilderTest, Call_Dot) {
)");
}
// Checks that dot() on a vec3<u32> is expanded into per-component
// OpCompositeExtract / OpIMul instructions summed with OpIAdd, instead of a
// single OpDot.
TEST_F(IntrinsicBuilderTest, Call_Dot_U32) {
auto* var = Global("v", ty.vec3<u32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "v", "v");
WrapInFunction(expr);
spirv::Builder& b = Build();
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
// The call's result id is expected to be %6, which is written by the final
// OpIAdd in the instruction dump below.
EXPECT_EQ(b.GenerateCallExpression(expr), 6u) << b.error();
// `OpTypeInt 32 0` is the unsigned 32-bit element type.
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 0
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%7 = OpLoad %3 %1
%8 = OpLoad %3 %1
%9 = OpCompositeExtract %4 %7 0
%10 = OpCompositeExtract %4 %8 0
%11 = OpIMul %4 %9 %10
%12 = OpCompositeExtract %4 %7 1
%13 = OpCompositeExtract %4 %8 1
%14 = OpIMul %4 %12 %13
%15 = OpIAdd %4 %11 %14
%16 = OpCompositeExtract %4 %7 2
%17 = OpCompositeExtract %4 %8 2
%18 = OpIMul %4 %16 %17
%6 = OpIAdd %4 %15 %18
)");
}
// Checks that dot() on a vec3<i32> is expanded into per-component
// OpCompositeExtract / OpIMul instructions summed with OpIAdd, instead of a
// single OpDot.
TEST_F(IntrinsicBuilderTest, Call_Dot_I32) {
auto* var = Global("v", ty.vec3<i32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "v", "v");
WrapInFunction(expr);
spirv::Builder& b = Build();
b.push_function(Function{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
// The call's result id is expected to be %6, which is written by the final
// OpIAdd in the instruction dump below.
EXPECT_EQ(b.GenerateCallExpression(expr), 6u) << b.error();
// `OpTypeInt 32 1` is the signed 32-bit element type.
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%7 = OpLoad %3 %1
%8 = OpLoad %3 %1
%9 = OpCompositeExtract %4 %7 0
%10 = OpCompositeExtract %4 %8 0
%11 = OpIMul %4 %9 %10
%12 = OpCompositeExtract %4 %7 1
%13 = OpCompositeExtract %4 %8 1
%14 = OpIMul %4 %12 %13
%15 = OpIAdd %4 %11 %14
%16 = OpCompositeExtract %4 %7 2
%17 = OpCompositeExtract %4 %8 2
%18 = OpIMul %4 %16 %17
%6 = OpIAdd %4 %15 %18
)");
}
using IntrinsicDeriveTest = IntrinsicBuilderTestWithParam<IntrinsicData>;
TEST_P(IntrinsicDeriveTest, Call_Derivative_Scalar) {
auto param = GetParam();