mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-20 18:29:23 +00:00
tint: const eval of reflect builtin
Bug: tint:1581 Change-Id: Ife4409ca897a5754fe6b76c650d26fd66ef5880f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111901 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
6239f58064
commit
ee7d6db047
@@ -518,7 +518,7 @@ fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn quantizeToF16<N: num>(vec<N, f32>) -> vec<N, f32>
|
||||
@const fn radians<T: fa_f32_f16>(T) -> T
|
||||
@const fn radians<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn reflect<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn reflect<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
fn refract<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
|
||||
@const fn reverseBits<T: iu32>(T) -> T
|
||||
@const fn reverseBits<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
|
||||
|
||||
@@ -1189,6 +1189,26 @@ ConstEval::Result ConstEval::Length(const Source& source,
|
||||
return Dispatch_fa_f32_f16(SqrtFunc(source, ty), d.Get());
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::Mul(const Source& source,
|
||||
const sem::Type* ty,
|
||||
const sem::Constant* v1,
|
||||
const sem::Constant* v2) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
return Dispatch_fia_fiu32_f16(MulFunc(source, c0->Type()), c0, c1);
|
||||
};
|
||||
return TransformBinaryElements(builder, ty, transform, v1, v2);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::Sub(const Source& source,
|
||||
const sem::Type* ty,
|
||||
const sem::Constant* v1,
|
||||
const sem::Constant* v2) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
return Dispatch_fia_fiu32_f16(SubFunc(source, c0->Type()), c0, c1);
|
||||
};
|
||||
return TransformBinaryElements(builder, ty, transform, v1, v2);
|
||||
}
|
||||
|
||||
auto ConstEval::Det2Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a, auto b, auto c, auto d) -> ImplResult {
|
||||
if (auto r = Det2(source, a, b, c, d)) {
|
||||
@@ -1481,21 +1501,13 @@ ConstEval::Result ConstEval::OpPlus(const sem::Type* ty,
|
||||
ConstEval::Result ConstEval::OpMinus(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
return Dispatch_fia_fiu32_f16(SubFunc(source, c0->Type()), c0, c1);
|
||||
};
|
||||
|
||||
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
|
||||
return Sub(source, ty, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::OpMultiply(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
return Dispatch_fia_fiu32_f16(MulFunc(source, c0->Type()), c0, c1);
|
||||
};
|
||||
|
||||
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
|
||||
return Mul(source, ty, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::OpMultiplyMatVec(const sem::Type* ty,
|
||||
@@ -2213,7 +2225,7 @@ ConstEval::Result ConstEval::degrees(const sem::Type* ty,
|
||||
ConstEval::Result ConstEval::determinant(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ImplResult {
|
||||
auto calculate = [&]() -> ConstEval::Result {
|
||||
auto* m = args[0];
|
||||
auto* mat_ty = m->Type()->As<sem::Matrix>();
|
||||
auto me = [&](size_t r, size_t c) { return m->Index(c)->Index(r); };
|
||||
@@ -2899,6 +2911,49 @@ ConstEval::Result ConstEval::radians(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::reflect(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto calculate = [&]() -> ConstEval::Result {
|
||||
// For the incident vector e1 and surface orientation e2, returns the reflection direction
|
||||
// e1 - 2 * dot(e2, e1) * e2.
|
||||
auto* e1 = args[0];
|
||||
auto* e2 = args[1];
|
||||
auto* vec_ty = ty->As<sem::Vector>();
|
||||
auto* el_ty = vec_ty->type();
|
||||
|
||||
// dot(e2, e1)
|
||||
auto dot_e2_e1 = Dot(source, e2, e1);
|
||||
if (!dot_e2_e1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// 2 * dot(e2, e1)
|
||||
auto mul2 = [&](auto v) -> ImplResult {
|
||||
using NumberT = decltype(v);
|
||||
return CreateElement(builder, source, el_ty, NumberT{NumberT{2} * v});
|
||||
};
|
||||
auto dot_e2_e1_2 = Dispatch_fa_f32_f16(mul2, dot_e2_e1.Get());
|
||||
if (!dot_e2_e1_2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// 2 * dot(e2, e1) * e2
|
||||
auto dot_e2_e1_2_e2 = Mul(source, ty, dot_e2_e1_2.Get(), e2);
|
||||
if (!dot_e2_e1_2_e2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// e1 - 2 * dot(e2, e1) * e2
|
||||
return Sub(source, ty, e1, dot_e2_e1_2_e2.Get());
|
||||
};
|
||||
auto r = calculate();
|
||||
if (!r) {
|
||||
AddNote("when calculating reflect", source);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -791,6 +791,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// reflect builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result reflect(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// reverseBits builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -1261,13 +1270,35 @@ class ConstEval {
|
||||
/// @returns the dot product
|
||||
Result Dot(const Source& source, const sem::Constant* v1, const sem::Constant* v2);
|
||||
|
||||
/// Return sthe length of c0
|
||||
/// Returns the length of c0
|
||||
/// @param source the source location
|
||||
/// @param ty the return type
|
||||
/// @param c0 the constant to calculate the length of
|
||||
/// @returns the length of c0
|
||||
Result Length(const Source& source, const sem::Type* ty, const sem::Constant* c0);
|
||||
|
||||
/// Returns the product of v1 and v2
|
||||
/// @param source the source location
|
||||
/// @param ty the return type
|
||||
/// @param v1 lhs value
|
||||
/// @param v2 rhs value
|
||||
/// @returns the product of v1 and v2
|
||||
Result Mul(const Source& source,
|
||||
const sem::Type* ty,
|
||||
const sem::Constant* v1,
|
||||
const sem::Constant* v2);
|
||||
|
||||
/// Returns the difference between v2 and v1
|
||||
/// @param source the source location
|
||||
/// @param ty the return type
|
||||
/// @param v1 lhs value
|
||||
/// @param v2 rhs value
|
||||
/// @returns the difference between v2 and v1
|
||||
Result Sub(const Source& source,
|
||||
const sem::Type* ty,
|
||||
const sem::Constant* v1,
|
||||
const sem::Constant* v2);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
};
|
||||
|
||||
|
||||
@@ -1934,6 +1934,61 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::ValuesIn(Concat(ReverseBitsCases<i32>(), //
|
||||
ReverseBitsCases<u32>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ReflectCases() {
|
||||
auto pos_y = Vec(T(0), T(1), T(0));
|
||||
auto neg_y = Vec(T(0), -T(1), T(0));
|
||||
auto pos_large_y = Vec(T(0), T(10000), T(0));
|
||||
auto neg_large_y = Vec(T(0), -T(10000), T(0));
|
||||
|
||||
auto cos_45 = T(0.70710678118654752440084436210485);
|
||||
auto pos_xyz = Vec(cos_45, cos_45, cos_45);
|
||||
|
||||
auto r = std::vector<Case>{
|
||||
C({Vec(T(1), -T(1), T(0)), pos_y}, Vec(T(1), T(1), T(0))),
|
||||
C({Vec(T(24), -T(42), T(0)), pos_y}, Vec(T(24), T(42), T(0))),
|
||||
// Flipping reflection vector doesn't change the result
|
||||
C({Vec(T(1), -T(1), T(0)), neg_y}, Vec(T(1), T(1), T(0))),
|
||||
C({Vec(T(24), -T(42), T(0)), neg_y}, Vec(T(24), T(42), T(0))),
|
||||
// Parallel input and reflection vectors: result is negation of input
|
||||
C({pos_y, pos_y}, neg_y),
|
||||
C({neg_y, pos_y}, pos_y),
|
||||
C({pos_large_y, pos_y}, neg_large_y),
|
||||
C({neg_large_y, pos_y}, pos_large_y),
|
||||
// Input axis vectors reflected by normalized(vec(1,1,1)) vector.
|
||||
C({Vec(T(1), T(0), T(0)), pos_xyz}, Vec(T(0), -T(1), -T(1))).FloatComp(0.02),
|
||||
C({Vec(T(0), T(1), T(0)), pos_xyz}, Vec(-T(1), T(0), -T(1))).FloatComp(0.02),
|
||||
C({Vec(T(0), T(0), T(1)), pos_xyz}, Vec(-T(1), -T(1), T(0))).FloatComp(0.02),
|
||||
C({Vec(-T(1), T(0), T(0)), pos_xyz}, Vec(T(0), T(1), T(1))).FloatComp(0.02),
|
||||
C({Vec(T(0), -T(1), T(0)), pos_xyz}, Vec(T(1), T(0), T(1))).FloatComp(0.02),
|
||||
C({Vec(T(0), T(0), -T(1)), pos_xyz}, Vec(T(1), T(1), T(0))).FloatComp(0.02),
|
||||
};
|
||||
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating reflect)";
|
||||
};
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
// Overflow the dot product operation
|
||||
E({Vec(T::Highest(), T::Highest(), T(0)), Vec(T(1), T(1), T(0))},
|
||||
error_msg(T::Highest(), "+", T::Highest())),
|
||||
E({Vec(T::Lowest(), T::Lowest(), T(0)), Vec(T(1), T(1), T(0))},
|
||||
error_msg(T::Lowest(), "+", T::Lowest())),
|
||||
});
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Reflect,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kReflect),
|
||||
testing::ValuesIn(
|
||||
// ReflectCases<f32>())));
|
||||
Concat(ReflectCases<AFloat>(), //
|
||||
ReflectCases<f32>(), //
|
||||
ReflectCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> RadiansCases() {
|
||||
return std::vector<Case>{
|
||||
|
||||
@@ -105,34 +105,34 @@ inline void CheckConstant(const sem::Constant* got_constant,
|
||||
auto got = std::get<T>(got_scalar);
|
||||
|
||||
if constexpr (std::is_same_v<bool, T>) {
|
||||
EXPECT_EQ(got, expected);
|
||||
EXPECT_EQ(got, expected) << "index: " << i;
|
||||
} else if constexpr (IsFloatingPoint<T>) {
|
||||
if (std::isnan(expected)) {
|
||||
EXPECT_TRUE(std::isnan(got));
|
||||
EXPECT_TRUE(std::isnan(got)) << "index: " << i;
|
||||
} else {
|
||||
if (flags.pos_or_neg) {
|
||||
got = Abs(got);
|
||||
}
|
||||
if (flags.float_compare) {
|
||||
if (flags.float_compare_epsilon) {
|
||||
EXPECT_NEAR(got, expected, *flags.float_compare_epsilon);
|
||||
EXPECT_NEAR(got, expected, *flags.float_compare_epsilon)
|
||||
<< "index: " << i;
|
||||
} else {
|
||||
EXPECT_FLOAT_EQ(got, expected);
|
||||
EXPECT_FLOAT_EQ(got, expected) << "index: " << i;
|
||||
}
|
||||
} else {
|
||||
EXPECT_EQ(got, expected);
|
||||
EXPECT_EQ(got, expected) << "index: " << i;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (flags.pos_or_neg) {
|
||||
auto got_abs = Abs(got);
|
||||
EXPECT_EQ(got_abs, expected);
|
||||
} else {
|
||||
EXPECT_EQ(got, expected);
|
||||
got = Abs(got);
|
||||
}
|
||||
EXPECT_EQ(got, expected) << "index: " << i;
|
||||
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(AInt(got), AInt(expected));
|
||||
EXPECT_EQ(AInt(got), AInt(expected)) << "index: " << i;
|
||||
}
|
||||
},
|
||||
expected_scalar);
|
||||
|
||||
@@ -13698,12 +13698,12 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[626],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::reflect,
|
||||
},
|
||||
{
|
||||
/* [449] */
|
||||
@@ -14415,7 +14415,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [63] */
|
||||
/* fn reflect<N : num, T : f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn reflect<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 1,
|
||||
/* overloads */ &kOverloads[448],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user