mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-18 01:15:39 +00:00
Add const-eval for smoothstep
This CL adds const-eval for `smoothstep`. Bug: tint:1581 Change-Id: I78aa5c4a39882f29ff78e37313e6c44708719095 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110176 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
19e5042ade
commit
32c28cbc90
@@ -535,8 +535,8 @@ fn refract<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
|
||||
@const fn sin<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn sinh<T: fa_f32_f16>(T) -> T
|
||||
@const fn sinh<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn smoothstep<T: f32_f16>(T, T, T) -> T
|
||||
fn smoothstep<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn smoothstep<T: fa_f32_f16>(@test_value(2) T, @test_value(4) T, @test_value(3) T) -> T
|
||||
@const fn smoothstep<N: num, T: fa_f32_f16>(@test_value(2) vec<N, T>, @test_value(4) vec<N, T>, @test_value(3) vec<N, T>) -> vec<N, T>
|
||||
@const fn sqrt<T: fa_f32_f16>(T) -> T
|
||||
@const fn sqrt<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn step<T: fa_f32_f16>(T, T) -> T
|
||||
|
||||
@@ -733,6 +733,41 @@ utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Div(NumberT a, NumberT b) {
|
||||
NumberT result;
|
||||
if constexpr (IsAbstract<NumberT>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedDiv(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError(OverflowErrorMessage(a, "/", b), *current_source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto divide_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
// For integers, lhs / 0 returns lhs
|
||||
if (rhs == 0) {
|
||||
return lhs;
|
||||
}
|
||||
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
// For signed integers, for lhs / -1, return lhs if lhs is the
|
||||
// most negative value
|
||||
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
|
||||
return lhs;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lhs / rhs;
|
||||
};
|
||||
result = divide_values(a.value, b.value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
@@ -878,6 +913,15 @@ auto ConstEval::MulFunc(const sem::Type* elem_ty) {
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::DivFunc(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2) -> ImplResult {
|
||||
if (auto r = Div(a1, a2)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
|
||||
if (auto r = Dot2(a1, a2, b1, b2)) {
|
||||
@@ -1366,42 +1410,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const sem::Type* ty,
|
||||
ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto i, auto j) -> ImplResult {
|
||||
using NumberT = decltype(i);
|
||||
NumberT result;
|
||||
if constexpr (IsAbstract<NumberT>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedDiv(i, j)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError(OverflowErrorMessage(i, "/", j), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto divide_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
// For integers, lhs / 0 returns lhs
|
||||
if (rhs == 0) {
|
||||
return lhs;
|
||||
}
|
||||
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
// For signed integers, for lhs / -1, return lhs if lhs is the
|
||||
// most negative value
|
||||
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
|
||||
return lhs;
|
||||
}
|
||||
}
|
||||
}
|
||||
return lhs / rhs;
|
||||
};
|
||||
result = divide_values(i.value, j.value);
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_fia_fiu32_f16(create, c0, c1);
|
||||
return Dispatch_fia_fiu32_f16(DivFunc(c0->Type()), c0, c1);
|
||||
};
|
||||
|
||||
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
|
||||
@@ -2397,6 +2408,59 @@ ConstEval::Result ConstEval::sinh(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::smoothstep(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
|
||||
const sem::Constant* c2) {
|
||||
auto create = [&](auto low, auto high, auto x) -> ImplResult {
|
||||
using NumberT = decltype(low);
|
||||
|
||||
auto err = [&] {
|
||||
AddNote("when calculating smoothstep", source);
|
||||
return utils::Failure;
|
||||
};
|
||||
|
||||
// t = clamp((x - low) / (high - low), 0.0, 1.0)
|
||||
auto x_minus_low = Sub(x, low);
|
||||
auto high_minus_low = Sub(high, low);
|
||||
if (!x_minus_low || !high_minus_low) {
|
||||
return err();
|
||||
}
|
||||
|
||||
auto div = Div(x_minus_low.Get(), high_minus_low.Get());
|
||||
if (!div) {
|
||||
return err();
|
||||
}
|
||||
|
||||
auto clamp = Clamp(div.Get(), NumberT(0), NumberT(1));
|
||||
auto t = clamp.Get();
|
||||
|
||||
// result = t * t * (3.0 - 2.0 * t)
|
||||
auto t_times_t = Mul(t, t);
|
||||
auto t_times_2 = Mul(NumberT(2), t);
|
||||
if (!t_times_t || !t_times_2) {
|
||||
return err();
|
||||
}
|
||||
|
||||
auto three_minus_t_times_2 = Sub(NumberT(3), t_times_2.Get());
|
||||
if (!three_minus_t_times_2) {
|
||||
return err();
|
||||
}
|
||||
|
||||
auto result = Mul(t_times_t.Get(), three_minus_t_times_2.Get());
|
||||
if (!result) {
|
||||
return err();
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result.Get());
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0, c1, c2);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::step(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
@@ -2587,4 +2651,8 @@ void ConstEval::AddWarning(const std::string& msg, const Source& source) const {
|
||||
builder.Diagnostics().add_warning(diag::System::Resolver, msg, source);
|
||||
}
|
||||
|
||||
void ConstEval::AddNote(const std::string& msg, const Source& source) const {
|
||||
builder.Diagnostics().add_note(diag::System::Resolver, msg, source);
|
||||
}
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
||||
@@ -719,6 +719,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// smoothstep builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result smoothstep(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// step builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -825,6 +834,9 @@ class ConstEval {
|
||||
/// Adds the given warning message to the diagnostics
|
||||
void AddWarning(const std::string& msg, const Source& source) const;
|
||||
|
||||
/// Adds the given note message to the diagnostics
|
||||
void AddNote(const std::string& msg, const Source& source) const;
|
||||
|
||||
/// Adds two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
@@ -846,6 +858,13 @@ class ConstEval {
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Mul(NumberT a, NumberT b);
|
||||
|
||||
/// Divides two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Div(NumberT a, NumberT b);
|
||||
|
||||
/// Returns the dot product of (a1,a2) with (b1,b2)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
@@ -925,6 +944,12 @@ class ConstEval {
|
||||
/// @returns the callable function
|
||||
auto MulFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Div, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto DivFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
|
||||
@@ -1648,6 +1648,90 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
SinhCases<f32>(),
|
||||
SinhCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SmoothstepCases() {
|
||||
return {
|
||||
// t == 0
|
||||
C({T(4), T(6), T(2)}, T(0)),
|
||||
// t == 1
|
||||
C({T(4), T(6), T(8)}, T(1)),
|
||||
// t == .5
|
||||
C({T(4), T(6), T(5)}, T(.5)),
|
||||
|
||||
// Vector tests
|
||||
C({Vec(T(4), T(4)), Vec(T(6), T(6)), Vec(T(2), T(8))}, Vec(T(0), T(1))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Smoothstep,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kSmoothstep),
|
||||
testing::ValuesIn(Concat(SmoothstepCases<AFloat>(), //
|
||||
SmoothstepCases<f32>(),
|
||||
SmoothstepCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SmoothstepAFloatErrorCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating smoothstep)";
|
||||
};
|
||||
|
||||
return {// `x - low` underflows
|
||||
E({T::Highest(), T(1), T::Lowest()}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// `high - low` underflows
|
||||
E({T::Highest(), T::Lowest(), T(0)}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// Divid by zero on `(x - low) / (high - low)`
|
||||
E({T(0), T(0), T(0)}, error_msg(T(0), "/", T(0)))};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
SmoothstepAFloatError,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kSmoothstep),
|
||||
testing::ValuesIn(SmoothstepAFloatErrorCases<AFloat>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SmoothstepF16ErrorCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating smoothstep)";
|
||||
};
|
||||
|
||||
return {// `x - low` underflows
|
||||
E({T::Highest(), T(1), T::Lowest()}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// `high - low` underflows
|
||||
E({T::Highest(), T::Lowest(), T(0)}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// Divid by zero on `(x - low) / (high - low)`
|
||||
E({T(0), T(0), T(0)}, error_msg(T(0), "/", T(0)))};
|
||||
}
|
||||
// TODO(crbug.com/tint/1581): Enable when non-abstract math is checked.
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
DISABLED_SmoothstepF16Error,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kSmoothstep),
|
||||
testing::ValuesIn(SmoothstepF16ErrorCases<f16>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SmoothstepF32ErrorCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating smoothstep)";
|
||||
};
|
||||
|
||||
return {// `x - low` underflows
|
||||
E({T::Highest(), T(1), T::Lowest()}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// `high - low` underflows
|
||||
E({T::Highest(), T::Lowest(), T(0)}, error_msg(T::Lowest(), "-", T::Highest())),
|
||||
// Divid by zero on `(x - low) / (high - low)`
|
||||
E({T(0), T(0), T(0)}, error_msg(T(0), "/", T(0)))};
|
||||
}
|
||||
// TODO(crbug.com/tint/1581): Enable when non-abstract math is checked.
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
DISABLED_SmoothstepF32Error,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kSmoothstep),
|
||||
testing::ValuesIn(SmoothstepF32ErrorCases<f32>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> StepCases() {
|
||||
return {
|
||||
|
||||
@@ -13113,24 +13113,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[489],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::smoothstep,
|
||||
},
|
||||
{
|
||||
/* [399] */
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[492],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::smoothstep,
|
||||
},
|
||||
{
|
||||
/* [400] */
|
||||
@@ -14504,8 +14504,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [72] */
|
||||
/* fn smoothstep<T : f32_f16>(T, T, T) -> T */
|
||||
/* fn smoothstep<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn smoothstep<T : fa_f32_f16>(@test_value(2) T, @test_value(4) T, @test_value(3) T) -> T */
|
||||
/* fn smoothstep<N : num, T : fa_f32_f16>(@test_value(2) vec<N, T>, @test_value(4) vec<N, T>, @test_value(3) vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[398],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user