mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 00:17:03 +00:00
Const eval for normalize
This CL adds const-eval for the `normalize` builtin. Bug: tint:1581 Change-Id: I6d5ba3e0ba507921137ca90c4caefa9daf88f735 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111740 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
8243aeda75
commit
8392a82a40
@@ -506,7 +506,7 @@ fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
|
||||
@const fn modf<T: fa_f32_f16>(@test_value(-1.5) T) -> __modf_result<T>
|
||||
@const fn modf<N: num, T: fa_f32_f16>(@test_value(-1.5) vec<N, T>) -> __modf_result_vec<N, T>
|
||||
fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn normalize<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn pack2x16float(vec2<f32>) -> u32
|
||||
@const fn pack2x16snorm(vec2<f32>) -> u32
|
||||
@const fn pack2x16unorm(vec2<f32>) -> u32
|
||||
|
||||
@@ -1415,7 +1415,7 @@ TEST_F(ResolverBuiltinFloatTest, Normalize_Error_NoParams) {
|
||||
EXPECT_EQ(r()->error(), R"(error: no matching call to normalize()
|
||||
|
||||
1 candidate function:
|
||||
normalize(vecN<T>) -> vecN<T> where: T is f32 or f16
|
||||
normalize(vecN<T>) -> vecN<T> where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
@@ -461,6 +461,8 @@ struct Composite : ImplConstant {
|
||||
/// CreateElement constructs and returns an Element<T>.
|
||||
template <typename T>
|
||||
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const sem::Type* t, T v) {
|
||||
TINT_ASSERT(Resolver, t->is_scalar());
|
||||
|
||||
if constexpr (IsFloatingPoint<T>) {
|
||||
if (!std::isfinite(v.value)) {
|
||||
auto msg = OverflowErrorMessage(v, builder.FriendlyName(t));
|
||||
@@ -652,8 +654,9 @@ ImplResult TransformBinaryElements(ProgramBuilder& builder,
|
||||
F&& f,
|
||||
const sem::Constant* c0,
|
||||
const sem::Constant* c1) {
|
||||
uint32_t n0 = 0, n1 = 0;
|
||||
uint32_t n0 = 0;
|
||||
sem::Type::ElementOf(c0->Type(), &n0);
|
||||
uint32_t n1 = 0;
|
||||
sem::Type::ElementOf(c1->Type(), &n1);
|
||||
uint32_t max_n = std::max(n0, n1);
|
||||
// If arity of both constants is 1, invoke callback
|
||||
@@ -664,7 +667,7 @@ ImplResult TransformBinaryElements(ProgramBuilder& builder,
|
||||
utils::Vector<const sem::Constant*, 8> els;
|
||||
els.Reserve(max_n);
|
||||
for (uint32_t i = 0; i < max_n; i++) {
|
||||
auto nested_or_self = [&](auto& c, uint32_t num_elems) {
|
||||
auto nested_or_self = [&](auto* c, uint32_t num_elems) {
|
||||
if (num_elems == 1) {
|
||||
return c;
|
||||
}
|
||||
@@ -2734,6 +2737,23 @@ ConstEval::Result ConstEval::modf(const sem::Type* ty,
|
||||
return CreateComposite(builder, ty, std::move(fields));
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::normalize(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto* len_ty = sem::Type::DeepestElementOf(ty);
|
||||
auto len = Length(source, len_ty, args[0]);
|
||||
if (!len) {
|
||||
AddNote("when calculating normalize", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
auto* v = len.Get();
|
||||
if (v->AllZero()) {
|
||||
AddError("zero length vector can not be normalized", source);
|
||||
return utils::Failure;
|
||||
}
|
||||
return OpDivide(ty, utils::Vector{args[0], v}, source);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pack2x16float(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -719,6 +719,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// normalize builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result normalize(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pack2x16float builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1694,6 +1694,36 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
ModfCases<f32>(), //
|
||||
ModfCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> NormalizeCases() {
|
||||
auto error_msg = [&](auto a) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, "*", a) + R"(
|
||||
12:34 note: when calculating normalize)";
|
||||
};
|
||||
|
||||
return {
|
||||
C({Vec(T(2), T(4), T(2))}, Vec(T(0.4082482905), T(0.8164965809), T(0.4082482905)))
|
||||
.FloatComp(),
|
||||
|
||||
C({Vec(T(2), T(0), T(0))}, Vec(T(1), T(0), T(0))),
|
||||
C({Vec(T(0), T(2), T(0))}, Vec(T(0), T(1), T(0))),
|
||||
C({Vec(T(0), T(0), T(2))}, Vec(T(0), T(0), T(1))),
|
||||
C({Vec(-T(2), T(0), T(0))}, Vec(-T(1), T(0), T(0))),
|
||||
C({Vec(T(0), -T(2), T(0))}, Vec(T(0), -T(1), T(0))),
|
||||
C({Vec(T(0), T(0), -T(2))}, Vec(T(0), T(0), -T(1))),
|
||||
|
||||
E({Vec(T(0), T(0), T(0))}, "12:34 error: zero length vector can not be normalized"),
|
||||
E({Vec(T::Highest(), T::Highest(), T::Highest())}, error_msg(T::Highest())),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Normalize,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kNormalize),
|
||||
testing::ValuesIn(Concat(NormalizeCases<AFloat>(), //
|
||||
NormalizeCases<f32>(), //
|
||||
NormalizeCases<f16>()))));
|
||||
|
||||
std::vector<Case> Pack4x8snormCases() {
|
||||
return {
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),
|
||||
|
||||
@@ -13626,12 +13626,12 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[880],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::normalize,
|
||||
},
|
||||
{
|
||||
/* [443] */
|
||||
@@ -14358,7 +14358,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [54] */
|
||||
/* fn normalize<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn normalize<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 1,
|
||||
/* overloads */ &kOverloads[442],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user