tint: const eval of refract builtin

Bug: tint:1581
Change-Id: Iff64e8a680fbbc82e1f8efe2e2f8e05af8e3692c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111920
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2022-11-28 21:14:36 +00:00
committed by Dawn LUCI CQ
parent be49ed5719
commit 3728a505d6
74 changed files with 1972 additions and 154 deletions

View File

@@ -2954,6 +2954,104 @@ ConstEval::Result ConstEval::reflect(const sem::Type* ty,
return r;
}
ConstEval::Result ConstEval::refract(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto* vec_ty = ty->As<sem::Vector>();
auto* el_ty = vec_ty->type();
auto compute_k = [&](auto e3, auto dot_e2_e1) -> ConstEval::Result {
using NumberT = decltype(e3);
// let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1))
auto e3_squared = Mul(source, e3, e3);
if (!e3_squared) {
return utils::Failure;
}
auto dot_e2_e1_squared = Mul(source, dot_e2_e1, dot_e2_e1);
if (!dot_e2_e1_squared) {
return utils::Failure;
}
auto r = Sub(source, NumberT(1), dot_e2_e1_squared.Get());
if (!r) {
return utils::Failure;
}
r = Mul(source, e3_squared.Get(), r.Get());
if (!r) {
return utils::Failure;
}
r = Sub(source, NumberT(1), r.Get());
if (!r) {
return utils::Failure;
}
return CreateElement(builder, source, el_ty, r.Get());
};
auto compute_e2_scale = [&](auto e3, auto dot_e2_e1, auto k) -> ConstEval::Result {
// e3 * dot(e2, e1) + sqrt(k)
auto sqrt_k = Sqrt(source, k);
if (!sqrt_k) {
return utils::Failure;
}
auto r = Mul(source, e3, dot_e2_e1);
if (!r) {
return utils::Failure;
}
r = Add(source, r.Get(), sqrt_k.Get());
if (!r) {
return utils::Failure;
}
return CreateElement(builder, source, el_ty, r.Get());
};
auto calculate = [&]() -> ConstEval::Result {
auto* e1 = args[0];
auto* e2 = args[1];
auto* e3 = args[2];
// For the incident vector e1 and surface normal e2, and the ratio of indices of refraction
// e3, let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1)). If k < 0.0, returns the
// refraction vector 0.0, otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1)
// + sqrt(k)) * e2.
// dot(e2, e1)
auto dot_e2_e1 = Dot(source, e2, e1);
if (!dot_e2_e1) {
return utils::Failure;
}
// let k = 1.0 - e3 * e3 * (1.0 - dot(e2, e1) * dot(e2, e1))
auto k = Dispatch_fa_f32_f16(compute_k, e3, dot_e2_e1.Get());
if (!k) {
return utils::Failure;
}
// If k < 0.0, returns the refraction vector 0.0
if (k.Get()->As<AFloat>() < 0) {
return ZeroValue(builder, ty);
}
// Otherwise return the refraction vector e3 * e1 - (e3 * dot(e2, e1) + sqrt(k)) * e2
auto e1_scaled = Mul(source, ty, e3, e1);
if (!e1_scaled) {
return utils::Failure;
}
auto e2_scale = Dispatch_fa_f32_f16(compute_e2_scale, e3, dot_e2_e1.Get(), k.Get());
if (!e2_scale) {
return utils::Failure;
}
auto e2_scaled = Mul(source, ty, e2_scale.Get(), e2);
if (!e1_scaled) {
return utils::Failure;
}
return Sub(source, ty, e1_scaled.Get(), e2_scaled.Get());
};
auto r = calculate();
if (!r) {
AddNote("when calculating refract", source);
}
return r;
}
ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -800,6 +800,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// refract builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result refract(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// reverseBits builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1983,11 +1983,91 @@ INSTANTIATE_TEST_SUITE_P( //
Reflect,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kReflect),
testing::ValuesIn(
// ReflectCases<f32>())));
Concat(ReflectCases<AFloat>(), //
ReflectCases<f32>(), //
ReflectCases<f16>()))));
testing::ValuesIn(Concat(ReflectCases<AFloat>(), //
ReflectCases<f32>(), //
ReflectCases<f16>()))));
template <typename T>
std::vector<Case> RefractCases() {
// Returns "eta" (Greek letter) that denotes the ratio of indices of refraction for the input
// and output vector angles from the normal vector.
auto eta = [](auto angle1, auto angle2) {
// Snell's law: sin(angle1) / sin(angle2) == n2 / n1
// We want the ratio of n1 to n2, so sin(angle2) / sin(angle1)
auto angle1_rads = T(angle1) * kPi<T> / T(180);
auto angle2_rads = T(angle2) * kPi<T> / T(180);
return T(std::sin(angle2_rads) / std::sin(angle1_rads));
};
auto zero = Vec(T(0), T(0), T(0));
auto pos_y = Vec(T(0), T(1), T(0));
auto neg_y = Vec(T(0), -T(1), T(0));
auto pos_x = Vec(T(1), T(0), T(0));
auto neg_x = Vec(-T(1), T(0), T(0));
auto cos_45 = T(0.70710678118654752440084436210485);
auto cos_30 = T(0.86602540378443864676372317075294);
auto down_right = Vec(T(cos_45), -T(cos_45), T(0));
auto up_right = Vec(T(cos_45), T(cos_45), T(0));
auto eps = 0.001;
if constexpr (std::is_same_v<T, f16>) {
eps = 0.1;
}
auto r = std::vector<Case>{
// e3 (eta) == 1, no refraction, so input is same as output
C({down_right, pos_y, Val(T(1))}, down_right),
C({neg_y, pos_y, Val(T(1))}, neg_y),
// Varying etas
C({down_right, pos_y, Val(eta(45, 45))}, down_right).FloatComp(eps), // e3 == 1
C({down_right, pos_y, Val(eta(45, 30))}, Vec(T(0.5), -T(cos_30), T(0))).FloatComp(eps),
C({down_right, pos_y, Val(eta(45, 60))}, Vec(T(cos_30), -T(0.5), T(0))).FloatComp(eps),
C({down_right, pos_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
// Flip input and normal, same result
C({up_right, neg_y, Val(eta(45, 45))}, up_right).FloatComp(eps), // e3 == 1
C({up_right, neg_y, Val(eta(45, 30))}, Vec(T(0.5), T(cos_30), T(0))).FloatComp(eps),
C({up_right, neg_y, Val(eta(45, 60))}, Vec(T(cos_30), T(0.5), T(0))).FloatComp(eps),
C({up_right, neg_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
// Flip only normal, result is flipped
C({down_right, neg_y, Val(eta(45, 45))}, up_right).FloatComp(eps), // e3 == 1
C({down_right, neg_y, Val(eta(45, 30))}, Vec(T(0.5), T(cos_30), T(0))).FloatComp(eps),
C({down_right, neg_y, Val(eta(45, 60))}, Vec(T(cos_30), T(0.5), T(0))).FloatComp(eps),
C({down_right, neg_y, Val(eta(45, 90))}, Vec(T(1), T(0), T(0))).FloatComp(eps),
// If k < 0.0, returns the refraction vector 0.0
C({down_right, pos_y, Val(T(2))}, zero).FloatComp(eps),
// A few more with a different normal (e2)
C({down_right, neg_x, Val(eta(45, 45))}, down_right).FloatComp(eps), // e3 == 1
C({down_right, neg_x, Val(eta(45, 30))}, Vec(cos_30, -T(0.5), T(0))).FloatComp(eps),
C({down_right, neg_x, Val(eta(45, 60))}, Vec(T(0.5), -T(cos_30), T(0))).FloatComp(eps),
};
auto error_msg = [](auto a, const char* op, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
12:34 note: when calculating refract)";
};
ConcatInto( //
r,
std::vector<Case>{
// Overflow the dot product operation
E({Vec(T::Highest(), T::Highest(), T(0)), Vec(T(1), T(1), T(0)), Val(T(1))},
error_msg(T::Highest(), "+", T::Highest())),
E({Vec(T::Lowest(), T::Lowest(), T(0)), Vec(T(1), T(1), T(0)), Val(T(1))},
error_msg(T::Lowest(), "+", T::Lowest())),
// Overflow the k^2 operation
E({down_right, pos_y, Val(T::Highest())}, error_msg(T::Highest(), "*", T::Highest())),
});
return r;
}
INSTANTIATE_TEST_SUITE_P( //
Refract,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kRefract),
testing::ValuesIn(Concat(RefractCases<AFloat>(), //
RefractCases<f32>(), //
RefractCases<f16>()))));
template <typename T>
std::vector<Case> RadiansCases() {

View File

@@ -13710,12 +13710,12 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[477],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::refract,
},
{
/* [450] */
@@ -14421,7 +14421,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [64] */
/* fn refract<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
/* fn refract<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T> */
/* num overloads */ 1,
/* overloads */ &kOverloads[449],
},