const eval of clamp

Bug: tint:1581
Change-Id: Icffaf023021e704b3beaacf35136a9d6263b31e3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97221
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2022-07-27 15:02:25 +00:00
committed by Dawn LUCI CQ
parent 4de90f0bb1
commit bf8ee35498
179 changed files with 5222 additions and 701 deletions

View File

@@ -135,6 +135,7 @@ match scalar_no_i32: f32 | f16 | u32 | bool
match scalar_no_u32: f32 | f16 | i32 | bool
match scalar_no_bool: f32 | f16 | i32 | u32
match fia_fi32_f16: fa | ia | f32 | i32 | f16
match fia_fiu32: fa | ia | f32 | i32 | u32
match fa_f32: fa | f32
match fa_f32_f16: fa | f32 | f16
match ia_iu32: ia | i32 | u32
@@ -377,8 +378,8 @@ fn atanh(f32) -> f32
fn atanh<N: num>(vec<N, f32>) -> vec<N, f32>
fn ceil(f32) -> f32
fn ceil<N: num>(vec<N, f32>) -> vec<N, f32>
fn clamp<T: fiu32>(T, T, T) -> T
fn clamp<N: num, T: fiu32>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn clamp<T: fia_fiu32>(T, T, T) -> T
@const fn clamp<N: num, T: fia_fiu32>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn cos(f32) -> f32
fn cos<N: num>(vec<N, f32>) -> vec<N, f32>
fn cosh(f32) -> f32

View File

@@ -1415,14 +1415,16 @@ TEST_P(ResolverBuiltinTest_ThreeParam_FloatOrInt, Error_NoParams) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "error: no matching call to " + std::string(param.name) +
"()\n\n"
"2 candidate functions:\n " +
std::string(param.name) +
"(T, T, T) -> T where: T is f32, i32 or u32\n " +
std::string(param.name) +
"(vecN<T>, vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 "
"or u32\n");
EXPECT_EQ(
r()->error(),
"error: no matching call to " + std::string(param.name) +
"()\n\n"
"2 candidate functions:\n " +
std::string(param.name) +
"(T, T, T) -> T where: T is abstract-float, abstract-int, f32, i32 or u32\n " +
std::string(param.name) +
"(vecN<T>, vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, "
"f32, i32 or u32\n");
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,

View File

@@ -79,6 +79,19 @@ auto Dispatch_fia_fi32_f16(F&& f, CONSTANTS&&... cs) {
});
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS>
auto Dispatch_fia_fiu32(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
[&](const sem::AbstractInt*) { return f(cs->template As<AInt>()...); },
[&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); },
[&](const sem::F32*) { return f(cs->template As<f32>()...); },
[&](const sem::I32*) { return f(cs->template As<i32>()...); },
[&](const sem::U32*) { return f(cs->template As<u32>()...); });
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS>
@@ -723,6 +736,20 @@ const sem::Constant* ConstEval::atan2(const sem::Type*,
args[1]->ConstantValue());
}
const sem::Constant* ConstEval::clamp(const sem::Type*,
utils::ConstVectorRef<const sem::Expression*> args) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) {
auto create = [&](auto e, auto low, auto high) {
return CreateElement(builder, c0->Type(),
decltype(e)(std::min(std::max(e, low), high)));
};
return Dispatch_fia_fiu32(create, c0, c1, c2);
};
return TransformElements(builder, transform, args[0]->ConstantValue(), args[1]->ConstantValue(),
args[2]->ConstantValue());
}
utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
const sem::Constant* value,
const Source& source) {

View File

@@ -198,6 +198,13 @@ class ConstEval {
const sem::Constant* atan2(const sem::Type* ty,
utils::ConstVectorRef<const sem::Expression*> args);
/// clamp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @return the result value, or null if the value cannot be calculated
const sem::Constant* clamp(const sem::Type* ty,
utils::ConstVectorRef<const sem::Expression*> args);
private:
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;

View File

@@ -3254,6 +3254,31 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(Atan2Cases<AFloat, true>(), //
Atan2Cases<f32, false>()))));
template <typename T>
std::vector<Case> ClampCases() {
return {
C({T(0), T(0), T(0)}, T(0)),
C({T(0), T(42), kHighest<T>}, T(42)),
C({kLowest<T>, T(0), T(42)}, T(0)),
C({T(0), kLowest<T>, kHighest<T>}, T(0)),
C({T(0), kHighest<T>, kLowest<T>}, kLowest<T>),
C({kHighest<T>, kHighest<T>, kHighest<T>}, kHighest<T>),
C({kLowest<T>, kLowest<T>, kLowest<T>}, kLowest<T>),
C({kHighest<T>, kLowest<T>, kHighest<T>}, kHighest<T>),
C({kLowest<T>, kLowest<T>, kHighest<T>}, kLowest<T>),
};
}
INSTANTIATE_TEST_SUITE_P( //
Clamp,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kClamp),
testing::ValuesIn(Concat(ClampCases<AInt>(), //
ClampCases<i32>(),
ClampCases<u32>(),
ClampCases<AFloat>(),
ClampCases<f32>()))));
} // namespace builtin
} // namespace

File diff suppressed because it is too large Load Diff

View File

@@ -1075,14 +1075,14 @@ INSTANTIATE_TEST_SUITE_P(
IntrinsicTableAbstractTernaryTest,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32, f32, f32, f32, AFloat, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, AFloat, AInt>(),
Case::Create<f32, f32, f32, f32, AFloat, AInt, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, AInt, AInt>(),
Case::Create<f32, f32, f32, f32, AInt, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, AInt, AFloat, AInt>(),
Case::Create<f32, f32, f32, f32, AInt, AInt, AFloat>(),
Case::Create<i32, i32, i32, i32, AInt, AInt, AInt>()
Case::Create<AFloat, AFloat, AFloat, AFloat, AFloat, AFloat, AFloat>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AFloat, AFloat, AInt>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AFloat, AInt, AFloat>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AFloat, AInt, AInt>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AInt, AFloat, AFloat>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AInt, AFloat, AInt>(),
Case::Create<AFloat, AFloat, AFloat, AFloat, AInt, AInt, AFloat>(),
Case::Create<AInt, AInt, AInt, AInt, AInt, AInt, AInt>()
// clang-format on
));
@@ -1091,14 +1091,14 @@ INSTANTIATE_TEST_SUITE_P(
IntrinsicTableAbstractTernaryTest,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, AIntV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AIntV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AIntV, AIntV>(),
Case::Create<f32V, f32V, f32V, f32V, AIntV, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AIntV, AFloatV, AIntV>(),
Case::Create<f32V, f32V, f32V, f32V, AIntV, AIntV, AFloatV>(),
Case::Create<i32V, i32V, i32V, i32V, AIntV, AIntV, AIntV>()
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AFloatV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AIntV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AIntV, AFloatV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AFloatV, AIntV, AIntV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AIntV, AFloatV, AFloatV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AIntV, AFloatV, AIntV>(),
Case::Create<AFloatV, AFloatV, AFloatV, AFloatV, AIntV, AIntV, AFloatV>(),
Case::Create<AIntV, AIntV, AIntV, AIntV, AIntV, AIntV, AIntV>()
// clang-format on
));
@@ -1270,6 +1270,68 @@ Case::Create<u32V, u32V, u32V, u32V, u32V, u32V, AIntV>(
// clang-format on
));
struct IntrinsicTableAbstractTernaryTest_NonConstEval : public ResolverTestWithParam<Case> {
std::unique_ptr<IntrinsicTable> table = IntrinsicTable::Create(*this);
};
TEST_P(IntrinsicTableAbstractTernaryTest_NonConstEval, MatchMix) {
auto* arg_a = GetParam().arg_a(*this);
auto* arg_b = GetParam().arg_b(*this);
auto* arg_c = GetParam().arg_c(*this);
auto builtin =
table->Lookup(sem::BuiltinType::kMix, utils::Vector{arg_a, arg_b, arg_c}, Source{{12, 34}});
bool matched = builtin.sem != nullptr;
bool expected_match = GetParam().expected_match;
EXPECT_EQ(matched, expected_match) << Diagnostics().str();
auto* result = builtin.sem ? builtin.sem->ReturnType() : nullptr;
auto* expected_result = GetParam().expected_result(*this);
EXPECT_TYPE(result, expected_result);
auto* param_a = builtin.sem ? builtin.sem->Parameters()[0]->Type() : nullptr;
auto* expected_param_a = GetParam().expected_param_a(*this);
EXPECT_TYPE(param_a, expected_param_a);
auto* param_b = builtin.sem ? builtin.sem->Parameters()[1]->Type() : nullptr;
auto* expected_param_b = GetParam().expected_param_b(*this);
EXPECT_TYPE(param_b, expected_param_b);
auto* param_c = builtin.sem ? builtin.sem->Parameters()[2]->Type() : nullptr;
auto* expected_param_c = GetParam().expected_param_c(*this);
EXPECT_TYPE(param_c, expected_param_c);
}
INSTANTIATE_TEST_SUITE_P(
AFloat_f32,
IntrinsicTableAbstractTernaryTest_NonConstEval,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32, f32, f32, f32, AFloat, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, AFloat, f32>(),
Case::Create<f32, f32, f32, f32, AFloat, f32, AFloat>(),
Case::Create<f32, f32, f32, f32, AFloat, f32, f32>(),
Case::Create<f32, f32, f32, f32, f32, AFloat, AFloat>(),
Case::Create<f32, f32, f32, f32, f32, AFloat, f32>(),
Case::Create<f32, f32, f32, f32, f32, f32, AFloat>()
// clang-format on
));
INSTANTIATE_TEST_SUITE_P(
VecAFloat_Vecf32,
IntrinsicTableAbstractTernaryTest_NonConstEval,
testing::Values( // clang-format off
// result | param a | param b | param c | arg a | arg b | arg c
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, AFloatV, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, AFloatV, f32V, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, AFloatV>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, AFloatV, f32V>(),
Case::Create<f32V, f32V, f32V, f32V, f32V, f32V, AFloatV> ()
// clang-format on
));
} // namespace AbstractTernaryTests
} // namespace