tint: const eval of 'select' builtin

Bug: tint:1581
Change-Id: I222433acb6a30245ab319a15081811f691aca9ff
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104440
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano
2022-10-05 01:39:53 +00:00
committed by Dawn LUCI CQ
parent 44f7b8ddf7
commit 8800d885e7
374 changed files with 12808 additions and 4954 deletions

View File

@@ -510,9 +510,9 @@ fn round<T: f32_f16>(T) -> T
fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn saturate<T: f32_f16>(T) -> T
fn saturate<T: f32_f16, N: num>(vec<N, T>) -> vec<N, T>
fn select<T: scalar>(T, T, bool) -> T
fn select<T: scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
fn select<N: num, T: scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
@const("select_bool") fn select<T: abstract_or_scalar>(T, T, bool) -> T
@const("select_bool") fn select<T: abstract_or_scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
@const("select_boolvec") fn select<N: num, T: abstract_or_scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
fn sign<T: f32_f16>(T) -> T
fn sign<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn sin<T: f32_f16>(T) -> T

View File

@@ -131,9 +131,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_NoParams) {
R"(error: no matching call to select()
3 candidate functions:
select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(T, T, bool) -> T where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
@@ -147,9 +147,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_SelectorInt) {
R"(error: no matching call to select(i32, i32, i32)
3 candidate functions:
select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(T, T, bool) -> T where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
@@ -164,9 +164,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_Matrix) {
R"(error: no matching call to select(mat2x2<f32>, mat2x2<f32>, bool)
3 candidate functions:
select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(T, T, bool) -> T where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
@@ -180,9 +180,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_MismatchTypes) {
R"(error: no matching call to select(f32, vec2<f32>, bool)
3 candidate functions:
select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(T, T, bool) -> T where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}
@@ -196,9 +196,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_MismatchVectorSize) {
R"(error: no matching call to select(vec2<f32>, vec3<f32>, bool)
3 candidate functions:
select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(T, T, bool) -> T where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
)");
}

View File

@@ -503,25 +503,30 @@ const ImplConstant* CreateComposite(ProgramBuilder& builder,
}
}
/// TransformElements constructs a new constant of type `composite_ty` by applying the
/// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all
/// input constants `cs` are of the same type.
namespace detail {
/// Implementation of TransformElements
template <typename F, typename... CONSTANTS>
ImplResult TransformElements(ProgramBuilder& builder,
const sem::Type* composite_ty,
F&& f,
size_t index,
CONSTANTS&&... cs) {
uint32_t n = 0;
auto* ty = First(cs...)->Type();
auto* el_ty = sem::Type::ElementOf(ty, &n);
if (el_ty == ty) {
return f(cs...);
constexpr bool kHasIndexParam = traits::IsType<size_t, traits::LastParameterType<F>>;
if constexpr (kHasIndexParam) {
return f(cs..., index);
} else {
return f(cs...);
}
}
utils::Vector<const sem::Constant*, 8> els;
els.Reserve(n);
for (uint32_t i = 0; i < n; i++) {
if (auto el = TransformElements(builder, sem::Type::ElementOf(composite_ty),
std::forward<F>(f), cs->Index(i)...)) {
if (auto el = detail::TransformElements(builder, sem::Type::ElementOf(composite_ty),
std::forward<F>(f), index + i, cs->Index(i)...)) {
els.Push(el.Get());
} else {
@@ -530,10 +535,24 @@ ImplResult TransformElements(ProgramBuilder& builder,
}
return CreateComposite(builder, composite_ty, std::move(els));
}
} // namespace detail
/// TransformElements constructs a new constant of type `composite_ty` by applying the
/// transformation function `f` on each of the most deeply nested elements of 'cs'. Assumes that all
/// input constants `cs` are of the same arity (all scalars or all vectors of the same size).
/// If `f`'s last argument is a `size_t`, then the index of the most deeply nested element inside
/// the most deeply nested aggregate type will be passed in.
template <typename F, typename... CONSTANTS>
ImplResult TransformElements(ProgramBuilder& builder,
const sem::Type* composite_ty,
F&& f,
CONSTANTS&&... cs) {
return detail::TransformElements(builder, composite_ty, f, 0, cs...);
}
/// TransformBinaryElements constructs a new constant of type `composite_ty` by applying the
/// transformation function 'f' on each of the most deeply nested elements of both `c0` and `c1`.
/// Unlike TransformElements, this function handles the constants being of different types, e.g.
/// Unlike TransformElements, this function handles the constants being of different arity, e.g.
/// vector-scalar, scalar-vector.
template <typename F>
ImplResult TransformBinaryElements(ProgramBuilder& builder,
@@ -1516,6 +1535,35 @@ ConstEval::Result ConstEval::clamp(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
ConstEval::Result ConstEval::select_bool(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto cond = args[2]->As<bool>();
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto f, auto t) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), cond ? t : f);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::select_boolvec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, size_t index) {
auto create = [&](auto f, auto t) -> ImplResult {
// Get corresponding bool value at the current vector value index
auto cond = args[2]->Index(index)->As<bool>();
return CreateElement(builder, sem::Type::DeepestElementOf(ty), cond ? t : f);
};
return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
};
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::Convert(const sem::Type* target_ty,
const sem::Constant* value,
const Source& source) {

View File

@@ -395,6 +395,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// select builtin with single bool third arg
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result select_bool(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// select builtin with vector of bool third arg
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result select_boolvec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
private:
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;

View File

@@ -2879,6 +2879,7 @@ using Types = std::variant< //
Value<builder::vec3<i32>>,
Value<builder::vec3<f32>>,
Value<builder::vec3<f16>>,
Value<builder::vec3<bool>>,
Value<builder::vec4<AInt>>,
Value<builder::vec4<AFloat>>,
@@ -2886,6 +2887,7 @@ using Types = std::variant< //
Value<builder::vec4<i32>>,
Value<builder::vec4<f32>>,
Value<builder::vec4<f16>>,
Value<builder::vec4<bool>>,
Value<builder::mat2x2<AInt>>,
Value<builder::mat2x2<AFloat>>,
@@ -4088,7 +4090,6 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
}
/// Creates a Case with Values for args and result
// template <typename T>
static Case C(std::initializer_list<Types> args, Types result) {
return Case{utils::Vector<Types, 8>{args}, std::move(result)};
}
@@ -4287,6 +4288,52 @@ INSTANTIATE_TEST_SUITE_P( //
ClampCases<f32>(),
ClampCases<f16>()))));
template <typename T>
std::vector<Case> SelectCases() {
return {
C({Val(T{1}), Val(T{2}), Val(false)}, Val(T{1})),
C({Val(T{1}), Val(T{2}), Val(true)}, Val(T{2})),
C({Val(T{2}), Val(T{1}), Val(false)}, Val(T{2})),
C({Val(T{2}), Val(T{1}), Val(true)}, Val(T{1})),
C({Vec(T{1}, T{2}), Vec(T{3}, T{4}), Vec(false, false)}, Vec(T{1}, T{2})),
C({Vec(T{1}, T{2}), Vec(T{3}, T{4}), Vec(false, true)}, Vec(T{1}, T{4})),
C({Vec(T{1}, T{2}), Vec(T{3}, T{4}), Vec(true, false)}, Vec(T{3}, T{2})),
C({Vec(T{1}, T{2}), Vec(T{3}, T{4}), Vec(true, true)}, Vec(T{3}, T{4})),
C({Vec(T{1}, T{1}, T{2}, T{2}), //
Vec(T{2}, T{2}, T{1}, T{1}), //
Vec(false, true, false, true)}, //
Vec(T{1}, T{2}, T{2}, T{1})), //
};
}
static std::vector<Case> SelectBoolCases() {
return {
C({Val(true), Val(false), Val(false)}, Val(true)),
C({Val(true), Val(false), Val(true)}, Val(false)),
C({Val(false), Val(true), Val(true)}, Val(true)),
C({Val(false), Val(true), Val(false)}, Val(false)),
C({Vec(true, true, false, false), //
Vec(false, false, true, true), //
Vec(false, true, true, false)}, //
Vec(true, false, true, false)), //
};
}
INSTANTIATE_TEST_SUITE_P( //
Select,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kSelect),
testing::ValuesIn(Concat(SelectCases<AInt>(), //
SelectCases<i32>(),
SelectCases<u32>(),
SelectCases<AFloat>(),
SelectCases<f32>(),
SelectCases<f16>(),
SelectBoolCases()))));
} // namespace builtin
} // namespace

File diff suppressed because it is too large Load Diff

View File

@@ -527,8 +527,7 @@ TEST_F(IntrinsicTableTest, MismatchOpenSizeMatrix) {
ASSERT_THAT(Diagnostics().str(), HasSubstr("no matching call"));
}
// TODO(amaiorano): Enable this test when constexpr `select` is implemented.
TEST_F(IntrinsicTableTest, DISABLED_MatchDifferentArgsElementType_ConstantEval) {
TEST_F(IntrinsicTableTest, MatchDifferentArgsElementType_ConstantEval) {
auto* af = create<sem::AbstractFloat>();
auto* bool_ = create<sem::Bool>();
auto result = table->Lookup(BuiltinType::kSelect, utils::Vector{af, af, bool_},

View File

@@ -73,15 +73,24 @@ struct SignatureOf<R (C::*)(ARGS...) const> {
/// SignatureOfT is an alias to `typename SignatureOf<F>::type`.
template <typename F>
using SignatureOfT = typename SignatureOf<F>::type;
using SignatureOfT = typename SignatureOf<Decay<F>>::type;
/// ParameterType is an alias to `typename SignatureOf<F>::type::parameter<N>`.
template <typename F, std::size_t N>
using ParameterType = typename SignatureOfT<F>::template parameter<N>;
using ParameterType = typename SignatureOfT<Decay<F>>::template parameter<N>;
/// LastParameterType returns the type of the last parameter of `F`. `F` must have at least one
/// parameter.
template <typename F>
using LastParameterType = ParameterType<F, SignatureOfT<Decay<F>>::parameter_count - 1>;
/// ReturnType is an alias to `typename SignatureOf<F>::type::ret`.
template <typename F>
using ReturnType = typename SignatureOfT<F>::ret;
using ReturnType = typename SignatureOfT<Decay<F>>::ret;
/// Returns true iff decayed T and decayed U are the same.
template <typename T, typename U>
static constexpr bool IsType = std::is_same<Decay<T>, Decay<U>>::value;
/// IsTypeOrDerived<T, BASE> is true iff `T` is of type `BASE`, or derives from
/// `BASE`.

View File

@@ -356,25 +356,29 @@ TEST_F(GlslGeneratorImplTest_Builtin, Builtin_Call) {
}
TEST_F(GlslGeneratorImplTest_Builtin, Select_Scalar) {
auto* call = Call("select", 1_f, 2_f, true);
GlobalVar("a", Expr(1_f), ast::AddressSpace::kPrivate);
GlobalVar("b", Expr(2_f), ast::AddressSpace::kPrivate);
auto* call = Call("select", "a", "b", true);
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "(true ? 2.0f : 1.0f)");
EXPECT_EQ(out.str(), "(true ? b : a)");
}
TEST_F(GlslGeneratorImplTest_Builtin, Select_Vector) {
auto* call = Call("select", vec2<i32>(1_i, 2_i), vec2<i32>(3_i, 4_i), vec2<bool>(true, false));
GlobalVar("a", vec2<i32>(1_i, 2_i), ast::AddressSpace::kPrivate);
GlobalVar("b", vec2<i32>(3_i, 4_i), ast::AddressSpace::kPrivate);
auto* call = Call("select", "a", "b", vec2<bool>(true, false));
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "mix(ivec2(1, 2), ivec2(3, 4), bvec2(true, false))");
EXPECT_EQ(out.str(), "mix(a, b, bvec2(true, false))");
}
TEST_F(GlslGeneratorImplTest_Builtin, FMA_f32) {

View File

@@ -356,25 +356,29 @@ TEST_F(HlslGeneratorImplTest_Builtin, Builtin_Call) {
}
TEST_F(HlslGeneratorImplTest_Builtin, Select_Scalar) {
auto* call = Call("select", 1_f, 2_f, true);
GlobalVar("a", Expr(1_f), ast::AddressSpace::kPrivate);
GlobalVar("b", Expr(2_f), ast::AddressSpace::kPrivate);
auto* call = Call("select", "a", "b", true);
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "(true ? 2.0f : 1.0f)");
EXPECT_EQ(out.str(), "(true ? b : a)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Select_Vector) {
auto* call = Call("select", vec2<i32>(1_i, 2_i), vec2<i32>(3_i, 4_i), vec2<bool>(true, false));
GlobalVar("a", vec2<i32>(1_i, 2_i), ast::AddressSpace::kPrivate);
GlobalVar("b", vec2<i32>(3_i, 4_i), ast::AddressSpace::kPrivate);
auto* call = Call("select", "a", "b", vec2<bool>(true, false));
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "(bool2(true, false) ? int2(3, 4) : int2(1, 2))");
EXPECT_EQ(out.str(), "(bool2(true, false) ? b : a)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Modf_Scalar_f32) {