mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 10:49:14 +00:00
tint: const eval of refract builtin
Bug: tint:1581 Change-Id: Iff64e8a680fbbc82e1f8efe2e2f8e05af8e3692c Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111920 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
be49ed5719
commit
3728a505d6
@@ -2954,6 +2954,104 @@ ConstEval::Result ConstEval::reflect(const sem::Type* ty,
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::refract(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto* vec_ty = ty->As<sem::Vector>();
|
||||
auto* el_ty = vec_ty->type();
|
||||
|
||||
auto compute_k = [&](auto e3, auto dot_e2_e1) -> ConstEval::Result {
|
||||
using NumberT = decltype(e3);
|
||||
// let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1))
|
||||
auto e3_squared = Mul(source, e3, e3);
|
||||
if (!e3_squared) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto dot_e2_e1_squared = Mul(source, dot_e2_e1, dot_e2_e1);
|
||||
if (!dot_e2_e1_squared) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Sub(source, NumberT(1), dot_e2_e1_squared.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Mul(source, e3_squared.Get(), r.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Sub(source, NumberT(1), r.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, el_ty, r.Get());
|
||||
};
|
||||
|
||||
auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result {
|
||||
// e3 * dot(e2, e1) + sqrt(k)
|
||||
auto sqrt_k = Sqrt(source, k);
|
||||
if (!sqrt_k) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Mul(source, e3, dot_e2_e1);
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(source, r.Get(), sqrt_k.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return CreateElement(builder, source, el_ty, r.Get());
|
||||
};
|
||||
|
||||
auto calculate = [&]() -> ConstEval::Result {
|
||||
auto* e1 = args[0];
|
||||
auto* e2 = args[1];
|
||||
auto* e3 = args[2];
|
||||
|
||||
// For the incident vector e1 and surface normal e2, and the ratio of indices of refraction
|
||||
// e3, let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1)). If k < 0.0, returns the
|
||||
// refraction vector 0.0, otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1)
|
||||
// + sqrt(k)) * e2.
|
||||
|
||||
// dot(e2, e1)
|
||||
auto dot_e2_e1 = Dot(source, e2, e1);
|
||||
if (!dot_e2_e1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1))
|
||||
auto k = Dispatch_fa_f32_f16(compute_k, e3, dot_e2_e1.Get());
|
||||
if (!k) {
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// If k < 0.0, returns the refraction vector 0.0
|
||||
if (k.Get()->As<AFloat>() < 0) {
|
||||
return ZeroValue(builder, ty);
|
||||
}
|
||||
|
||||
// Otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2
|
||||
auto e1_scaled = Mul(source, ty, e3, e1);
|
||||
if (!e1_scaled) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto e2_scale = Dispatch_fa_f32_f16(compute_e2_scale, e3, dot_e2_e1.Get(), k.Get());
|
||||
if (!e2_scale) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto e2_scaled = Mul(source, ty, e2_scale.Get(), e2);
|
||||
if (!e1_scaled) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return Sub(source, ty, e1_scaled.Get(), e2_scaled.Get());
|
||||
};
|
||||
auto r = calculate();
|
||||
if (!r) {
|
||||
AddNote("when calculating refract", source);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -800,6 +800,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// refract builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result refract(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// reverseBits builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1983,11 +1983,91 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
Reflect,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kReflect),
|
||||
testing::ValuesIn(
|
||||
// ReflectCases<f32>())));
|
||||
Concat(ReflectCases<AFloat>(), //
|
||||
ReflectCases<f32>(), //
|
||||
ReflectCases<f16>()))));
|
||||
testing::ValuesIn(Concat(ReflectCases<AFloat>(), //
|
||||
ReflectCases<f32>(), //
|
||||
ReflectCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> RefractCases() {
|
||||
// Returns "eta" (Greek letter) that denotes the ratio of indices of refraction for the input
|
||||
// and output vector angles from the normal vector.
|
||||
auto eta = [](auto angle1, auto angle2) {
|
||||
// Snell's law: sin(angle1) / sin(angle2) == n2 / n1
|
||||
// We want the ratio of n1 to n2, so sin(angle2) / sin(angle1)
|
||||
auto angle1_rads = T(angle1) * kPi<T> / T(180);
|
||||
auto angle2_rads = T(angle2) * kPi<T> / T(180);
|
||||
return T(std::sin(angle2_rads) / std::sin(angle1_rads));
|
||||
};
|
||||
|
||||
auto zero = Vec(T(0), T(0), T(0));
|
||||
auto pos_y = Vec(T(0), T(1), T(0));
|
||||
auto neg_y = Vec(T(0), -T(1), T(0));
|
||||
auto pos_x = Vec(T(1), T(0), T(0));
|
||||
auto neg_x = Vec(-T(1), T(0), T(0));
|
||||
auto cos_45 = T(0.70710678118654752440084436210485);
|
||||
auto cos_30 = T(0.86602540378443864676372317075294);
|
||||
auto down_right = Vec(T(cos_45), -T(cos_45), T(0));
|
||||
auto up_right = Vec(T(cos_45), T(cos_45), T(0));
|
||||
|
||||
auto eps = 0.001;
|
||||
if constexpr (std::is_same_v<T, f16>) {
|
||||
eps = 0.1;
|
||||
}
|
||||
|
||||
auto r = std::vector<Case>{
|
||||
// e3 (eta) == 1, no refraction, so input is same as output
|
||||
C({down_right, pos_y, Val(T(1))}, down_right),
|
||||
C({neg_y, pos_y, Val(T(1))}, neg_y),
|
||||
// Varying etas
|
||||
C({down_right, pos_y, Val(eta(45, 45))}, down_right).FloatComp(eps), // e3 == 1
|
||||
C({down_right, pos_y, Val(eta(45, 30))}, Vec(T(0.5), -T(cos_30), T(0))).FloatComp(eps),
|
||||
C({down_right, pos_y, Val(eta(45, 60))}, Vec(T(cos_30), -T(0.5), T(0))).FloatComp(eps),
|
||||
C({down_right, pos_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
|
||||
// Flip input and normal, same result
|
||||
C({up_right, neg_y, Val(eta(45, 45))}, up_right).FloatComp(eps), // e3 == 1
|
||||
C({up_right, neg_y, Val(eta(45, 30))}, Vec(T(0.5), T(cos_30), T(0))).FloatComp(eps),
|
||||
C({up_right, neg_y, Val(eta(45, 60))}, Vec(T(cos_30), T(0.5), T(0))).FloatComp(eps),
|
||||
C({up_right, neg_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
|
||||
// Flip only normal, result is flipped
|
||||
C({down_right, neg_y, Val(eta(45, 45))}, up_right).FloatComp(eps), // e3 == 1
|
||||
C({down_right, neg_y, Val(eta(45, 30))}, Vec(T(0.5), T(cos_30), T(0))).FloatComp(eps),
|
||||
C({down_right, neg_y, Val(eta(45, 60))}, Vec(T(cos_30), T(0.5), T(0))).FloatComp(eps),
|
||||
C({down_right, neg_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
|
||||
|
||||
// If k < 0.0, returns the refraction vector 0.0
|
||||
C({down_right, pos_y, Val(T(2))}, zero).FloatComp(eps),
|
||||
|
||||
// A few more with a different normal (e2)
|
||||
C({down_right, neg_x, Val(eta(45, 45))}, down_right).FloatComp(eps), // e3 == 1
|
||||
C({down_right, neg_x, Val(eta(45, 30))}, Vec(cos_30, -T(0.5), T(0))).FloatComp(eps),
|
||||
C({down_right, neg_x, Val(eta(45, 60))}, Vec(T(0.5), -T(cos_30), T(0))).FloatComp(eps),
|
||||
};
|
||||
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating refract)";
|
||||
};
|
||||
ConcatInto( //
|
||||
r,
|
||||
std::vector<Case>{
|
||||
// Overflow the dot product operation
|
||||
E({Vec(T::Highest(), T::Highest(), T(0)), Vec(T(1), T(1), T(0)), Val(T(1))},
|
||||
error_msg(T::Highest(), "+", T::Highest())),
|
||||
E({Vec(T::Lowest(), T::Lowest(), T(0)), Vec(T(1), T(1), T(0)), Val(T(1))},
|
||||
error_msg(T::Lowest(), "+", T::Lowest())),
|
||||
// Overflow the k^2 operation
|
||||
E({down_right, pos_y, Val(T::Highest())}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Refract,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kRefract),
|
||||
testing::ValuesIn(Concat(RefractCases<AFloat>(), //
|
||||
RefractCases<f32>(), //
|
||||
RefractCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> RadiansCases() {
|
||||
|
||||
@@ -13710,12 +13710,12 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[477],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::refract,
|
||||
},
|
||||
{
|
||||
/* [450] */
|
||||
@@ -14421,7 +14421,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [64] */
|
||||
/* fn refract<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
|
||||
/* fn refract<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
|
||||
/* num overloads */ 1,
|
||||
/* overloads */ &kOverloads[449],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user