tint: Implement const-eval of modf

Bug: tint:1581
Change-Id: I53151ebf43601cd6afcdd2ec91d0ff9c4e650ef3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111241
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-11-23 00:05:05 +00:00
committed by Dawn LUCI CQ
parent 92d858ac3c
commit 329dfd7cbd
132 changed files with 966 additions and 774 deletions

View File

@@ -2266,6 +2266,40 @@ ConstEval::Result ConstEval::min(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::modf(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform_fract = [&](const sem::Constant* c) {
auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(),
decltype(e)(e.value - std::trunc(e.value)));
};
return Dispatch_fa_f32_f16(create, c);
};
auto transform_whole = [&](const sem::Constant* c) {
auto create = [&](auto e) {
return CreateElement(builder, source, c->Type(), decltype(e)(std::trunc(e.value)));
};
return Dispatch_fa_f32_f16(create, c);
};
utils::Vector<const sem::Constant*, 2> fields;
if (auto fract = TransformElements(builder, args[0]->Type(), transform_fract, args[0])) {
fields.Push(fract.Get());
} else {
return utils::Failure;
}
if (auto whole = TransformElements(builder, args[0]->Type(), transform_whole, args[0])) {
fields.Push(whole.Get());
} else {
return utils::Failure;
}
return CreateComposite(builder, ty, std::move(fields));
}
ConstEval::Result ConstEval::pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -628,6 +628,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// modf builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result modf(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -98,6 +98,11 @@ static Case C(std::initializer_list<Types> args, Types result) {
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
}
/// Creates a Case with Values for args and result
static Case C(std::initializer_list<Types> args, std::initializer_list<Types> results) {
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{results}};
}
/// Convenience overload that creates a Case with just scalars
static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
utils::Vector<Types, 8> args;
@@ -109,6 +114,20 @@ static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
}
/// Creates a Case with Values for args and result
static Case C(std::initializer_list<ScalarTypes> sargs,
std::initializer_list<ScalarTypes> sresults) {
utils::Vector<Types, 8> args;
for (auto& sa : sargs) {
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
}
utils::Vector<Types, 2> results;
for (auto& sa : sresults) {
std::visit([&](auto&& v) { return results.Push(Val(v)); }, sa);
}
return Case{std::move(args), std::move(results)};
}
/// Creates a Case with Values for args and expected error
static Case E(std::initializer_list<Types> args, std::string err) {
return Case{utils::Vector<Types, 8>{args}, std::move(err)};
@@ -1290,6 +1309,38 @@ INSTANTIATE_TEST_SUITE_P( //
MinCases<AFloat>(),
MinCases<f32>(),
MinCases<f16>()))));
template <typename T>
std::vector<Case> ModfCases() {
return {
// Scalar tests
// in fract whole
C({T(0.0)}, {T(0.0), T(0.0)}), //
C({T(1.0)}, {T(0.0), T(1.0)}), //
C({T(2.0)}, {T(0.0), T(2.0)}), //
C({T(1.5)}, {T(0.5), T(1.0)}), //
C({T(4.25)}, {T(0.25), T(4.0)}), //
C({T(-1.0)}, {T(0.0), T(-1.0)}), //
C({T(-2.0)}, {T(0.0), T(-2.0)}), //
C({T(-1.5)}, {T(-0.5), T(-1.0)}), //
C({T(-4.25)}, {T(-0.25), T(-4.0)}), //
C({T::Lowest()}, {T(0.0), T::Lowest()}), //
C({T::Highest()}, {T(0.0), T::Highest()}), //
// Vector tests
// in fract whole
C({Vec(T(0.0), T(0.0))}, {Vec(T(0.0), T(0.0)), Vec(T(0.0), T(0.0))}),
C({Vec(T(1.0), T(2.0))}, {Vec(T(0.0), T(0.0)), Vec(T(1), T(2))}),
C({Vec(T(-2.0), T(1.0))}, {Vec(T(0.0), T(0.0)), Vec(T(-2), T(1))}),
C({Vec(T(1.5), T(-2.25))}, {Vec(T(0.5), T(-0.25)), Vec(T(1.0), T(-2.0))}),
C({Vec(T::Lowest(), T::Highest())}, {Vec(T(0.0), T(0.0)), Vec(T::Lowest(), T::Highest())}),
};
}
INSTANTIATE_TEST_SUITE_P( //
Modf,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kModf),
testing::ValuesIn(Concat(ModfCases<f32>(), //
ModfCases<f16>()))));
std::vector<Case> Pack4x8snormCases() {
return {

View File

@@ -12851,7 +12851,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[878],
/* return matcher indices */ &kMatcherIndices[106],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::modf,
},
{
/* [378] */
@@ -12863,7 +12863,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[879],
/* return matcher indices */ &kMatcherIndices[45],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::modf,
},
{
/* [379] */
@@ -14351,8 +14351,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [53] */
/* fn modf<T : f32_f16>(T) -> __modf_result<T> */
/* fn modf<N : num, T : f32_f16>(vec<N, T>) -> __modf_result_vec<N, T> */
/* fn modf<T : f32_f16>(@test_value(-1.5) T) -> __modf_result<T> */
/* fn modf<N : num, T : f32_f16>(@test_value(-1.5) vec<N, T>) -> __modf_result_vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[377],
},

View File

@@ -530,9 +530,13 @@ bool Validator::AddressSpaceLayout(const sem::Variable* var,
}
if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
if (!AddressSpaceLayout(str, var->AddressSpace(), str->Declaration()->source, layouts)) {
AddNote("see declaration of variable", var->Declaration()->source);
return false;
// Check the structure has a declaration. Builtins like modf() and frexp() return untypeable
// structures, and so they have no declaration. Just skip validation for these.
if (auto* str_decl = str->Declaration()) {
if (!AddressSpaceLayout(str, var->AddressSpace(), str_decl->source, layouts)) {
AddNote("see declaration of variable", var->Declaration()->source);
return false;
}
}
} else {
Source source = var->Declaration()->source;