Const eval for normalize

This CL adds const-eval for the `normalize` builtin.

Bug: tint:1581
Change-Id: I6d5ba3e0ba507921137ca90c4caefa9daf88f735
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111740
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-25 04:25:18 +00:00
committed by Dawn LUCI CQ
parent 8243aeda75
commit 8392a82a40
75 changed files with 1830 additions and 152 deletions

View File

@@ -506,7 +506,7 @@ fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
@const fn modf<T: fa_f32_f16>(@test_value(-1.5) T) -> __modf_result<T>
@const fn modf<N: num, T: fa_f32_f16>(@test_value(-1.5) vec<N, T>) -> __modf_result_vec<N, T>
fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn normalize<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
@const fn pack2x16float(vec2<f32>) -> u32
@const fn pack2x16snorm(vec2<f32>) -> u32
@const fn pack2x16unorm(vec2<f32>) -> u32

View File

@@ -1415,7 +1415,7 @@ TEST_F(ResolverBuiltinFloatTest, Normalize_Error_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to normalize()
1 candidate function:
normalize(vecN<T>) -> vecN<T> where: T is f32 or f16
normalize(vecN<T>) -> vecN<T> where: T is abstract-float, f32 or f16
)");
}

View File

@@ -461,6 +461,8 @@ struct Composite : ImplConstant {
/// CreateElement constructs and returns an Element<T>.
template <typename T>
ImplResult CreateElement(ProgramBuilder& builder, const Source& source, const sem::Type* t, T v) {
TINT_ASSERT(Resolver, t->is_scalar());
if constexpr (IsFloatingPoint<T>) {
if (!std::isfinite(v.value)) {
auto msg = OverflowErrorMessage(v, builder.FriendlyName(t));
@@ -652,8 +654,9 @@ ImplResult TransformBinaryElements(ProgramBuilder& builder,
F&& f,
const sem::Constant* c0,
const sem::Constant* c1) {
uint32_t n0 = 0, n1 = 0;
uint32_t n0 = 0;
sem::Type::ElementOf(c0->Type(), &n0);
uint32_t n1 = 0;
sem::Type::ElementOf(c1->Type(), &n1);
uint32_t max_n = std::max(n0, n1);
// If arity of both constants is 1, invoke callback
@@ -664,7 +667,7 @@ ImplResult TransformBinaryElements(ProgramBuilder& builder,
utils::Vector<const sem::Constant*, 8> els;
els.Reserve(max_n);
for (uint32_t i = 0; i < max_n; i++) {
auto nested_or_self = [&](auto& c, uint32_t num_elems) {
auto nested_or_self = [&](auto* c, uint32_t num_elems) {
if (num_elems == 1) {
return c;
}
@@ -2734,6 +2737,23 @@ ConstEval::Result ConstEval::modf(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(fields));
}
ConstEval::Result ConstEval::normalize(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto* len_ty = sem::Type::DeepestElementOf(ty);
auto len = Length(source, len_ty, args[0]);
if (!len) {
AddNote("when calculating normalize", source);
return utils::Failure;
}
auto* v = len.Get();
if (v->AllZero()) {
AddError("zero length vector can not be normalized", source);
return utils::Failure;
}
return OpDivide(ty, utils::Vector{args[0], v}, source);
}
ConstEval::Result ConstEval::pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -719,6 +719,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// normalize builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result normalize(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1694,6 +1694,36 @@ INSTANTIATE_TEST_SUITE_P( //
ModfCases<f32>(), //
ModfCases<f16>()))));
template <typename T>
std::vector<Case> NormalizeCases() {
auto error_msg = [&](auto a) {
return "12:34 error: " + OverflowErrorMessage(a, "*", a) + R"(
12:34 note: when calculating normalize)";
};
return {
C({Vec(T(2), T(4), T(2))}, Vec(T(0.4082482905), T(0.8164965809), T(0.4082482905)))
.FloatComp(),
C({Vec(T(2), T(0), T(0))}, Vec(T(1), T(0), T(0))),
C({Vec(T(0), T(2), T(0))}, Vec(T(0), T(1), T(0))),
C({Vec(T(0), T(0), T(2))}, Vec(T(0), T(0), T(1))),
C({Vec(-T(2), T(0), T(0))}, Vec(-T(1), T(0), T(0))),
C({Vec(T(0), -T(2), T(0))}, Vec(T(0), -T(1), T(0))),
C({Vec(T(0), T(0), -T(2))}, Vec(T(0), T(0), -T(1))),
E({Vec(T(0), T(0), T(0))}, "12:34 error: zero length vector can not be normalized"),
E({Vec(T::Highest(), T::Highest(), T::Highest())}, error_msg(T::Highest())),
};
}
INSTANTIATE_TEST_SUITE_P( //
Normalize,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kNormalize),
testing::ValuesIn(Concat(NormalizeCases<AFloat>(), //
NormalizeCases<f32>(), //
NormalizeCases<f16>()))));
std::vector<Case> Pack4x8snormCases() {
return {
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),

View File

@@ -13626,12 +13626,12 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[880],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::normalize,
},
{
/* [443] */
@@ -14358,7 +14358,7 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [54] */
/* fn normalize<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn normalize<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 1,
/* overloads */ &kOverloads[442],
},