tint: const eval of modulo operator

Bug: tint:1581
Change-Id: Icf9aaab29f45a41e6c367f60bf98ccc8958a56c7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112322
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-11-30 16:02:49 +00:00 committed by Dawn LUCI CQ
parent 171c542bf7
commit babc21f931
5 changed files with 246 additions and 18 deletions

View File

@ -932,10 +932,10 @@ init mat4x4<T: fa_f32_f16>(vec4<T>, vec4<T>, vec4<T>, vec4<T>) -> mat4x4<T>
@const op / <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
@const op / <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op % <T: fiu32_f16>(T, T) -> T
op % <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op % <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
op % <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op % <T: fia_fiu32_f16>(T, T) -> T
@const op % <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const op % <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
@const op % <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op ^ <T: ia_iu32>(T, T) -> T
@const op ^ <T: ia_iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>

View File

@ -744,7 +744,6 @@ utils::Result<NumberT> ConstEval::Mul(const Source& source, NumberT a, NumberT b
using T = UnwrapNumber<NumberT>;
NumberT result;
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
// Check for over/underflow for abstract values
if (auto r = CheckedMul(a, b)) {
result = r->value;
} else {
@ -770,7 +769,6 @@ template <typename NumberT>
utils::Result<NumberT> ConstEval::Div(const Source& source, NumberT a, NumberT b) {
NumberT result;
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
// Check for over/underflow for abstract values
if (auto r = CheckedDiv(a, b)) {
result = r->value;
} else {
@ -799,6 +797,38 @@ utils::Result<NumberT> ConstEval::Div(const Source& source, NumberT a, NumberT b
return result;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Mod(const Source& source, NumberT a, NumberT b) {
NumberT result;
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
if (auto r = CheckedMod(a, b)) {
result = r->value;
} else {
AddError(OverflowErrorMessage(a, "%", b), source);
return utils::Failure;
}
} else {
using T = UnwrapNumber<NumberT>;
auto lhs = a.value;
auto rhs = b.value;
if (rhs == 0) {
// lhs % 0 is an error
AddError(OverflowErrorMessage(a, "%", b), source);
return utils::Failure;
}
if constexpr (std::is_signed_v<T>) {
// For signed integers, lhs % -1 where lhs is the
// most negative value is an error
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
AddError(OverflowErrorMessage(a, "%", b), source);
return utils::Failure;
}
}
result = lhs % rhs;
}
return result;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot2(const Source& source,
NumberT a1,
@ -1111,6 +1141,15 @@ auto ConstEval::DivFunc(const Source& source, const sem::Type* elem_ty) {
};
}
auto ConstEval::ModFunc(const Source& source, const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Mod(source, a1, a2)) {
return CreateElement(builder, source, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::Dot2Func(const Source& source, const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
if (auto r = Dot2(source, a1, a2, b1, b2)) {
@ -1683,6 +1722,16 @@ ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::OpModulo(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1);
};
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::OpEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@ -283,6 +283,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Modulo operator '%'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpModulo(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Equality operator '=='
/// @param ty the expression type
/// @param args the input arguments
@ -1040,6 +1049,14 @@ class ConstEval {
template <typename NumberT>
utils::Result<NumberT> Div(const Source& source, NumberT a, NumberT b);
/// Returns the (signed) remainder of the division of two Number<T>s
/// @param source the source location
/// @param a the lhs number
/// @param b the rhs number
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Mod(const Source& source, NumberT a, NumberT b);
/// Returns the dot product of (a1,a2) with (b1,b2)
/// @param source the source location
/// @param a1 component 1 of lhs vector
@ -1216,6 +1233,13 @@ class ConstEval {
/// @returns the callable function
auto DivFunc(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls Mod, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param source the source location
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto ModFunc(const Source& source, const sem::Type* elem_ty);
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param source the source location

View File

@ -485,6 +485,161 @@ INSTANTIATE_TEST_SUITE_P(Div,
OpDivFloatCases<f32>(),
OpDivFloatCases<f16>()))));
template <typename T>
std::vector<Case> OpModCases() {
auto error_msg = [](auto a, auto b) {
return "12:34 error: " + OverflowErrorMessage(a, "%", b);
};
// Common cases for all types
std::vector<Case> r = {
C(T{0}, T{1}, T{0}), //
C(T{1}, T{1}, T{0}), //
C(T{10}, T{1}, T{0}), //
C(T{10}, T{2}, T{0}), //
C(T{10}, T{3}, T{1}), //
C(T{10}, T{4}, T{2}), //
C(T{10}, T{5}, T{0}), //
C(T{10}, T{6}, T{4}), //
C(T{10}, T{5}, T{0}), //
C(T{10}, T{8}, T{2}), //
C(T{10}, T{9}, T{1}), //
C(T{10}, T{10}, T{0}), //
// Error on divide by zero
E(T{123}, T{0}, error_msg(T{123}, T{0})),
E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})),
E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})),
};
if constexpr (IsIntegral<T>) {
ConcatInto( //
r, std::vector<Case>{
C(T::Highest(), T{T::Highest() - T{1}}, T{1}),
});
}
if constexpr (IsSignedIntegral<T>) {
ConcatInto( //
r, std::vector<Case>{
C(T::Lowest(), T{T::Lowest() + T{1}}, -T(1)),
// Error on most negative integer divided by -1
E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})),
});
}
// Negative values (both signed integrals and floating point)
if constexpr (IsSignedIntegral<T> || IsFloatingPoint<T>) {
ConcatInto( //
r, std::vector<Case>{
C(-T{1}, T{1}, T{0}), //
// lhs negative, rhs positive
C(-T{10}, T{1}, T{0}), //
C(-T{10}, T{2}, T{0}), //
C(-T{10}, T{3}, -T{1}), //
C(-T{10}, T{4}, -T{2}), //
C(-T{10}, T{5}, T{0}), //
C(-T{10}, T{6}, -T{4}), //
C(-T{10}, T{5}, T{0}), //
C(-T{10}, T{8}, -T{2}), //
C(-T{10}, T{9}, -T{1}), //
C(-T{10}, T{10}, T{0}), //
// lhs positive, rhs negative
C(T{10}, -T{1}, T{0}), //
C(T{10}, -T{2}, T{0}), //
C(T{10}, -T{3}, T{1}), //
C(T{10}, -T{4}, T{2}), //
C(T{10}, -T{5}, T{0}), //
C(T{10}, -T{6}, T{4}), //
C(T{10}, -T{5}, T{0}), //
C(T{10}, -T{8}, T{2}), //
C(T{10}, -T{9}, T{1}), //
C(T{10}, -T{10}, T{0}), //
// lhs negative, rhs negative
C(-T{10}, -T{1}, T{0}), //
C(-T{10}, -T{2}, T{0}), //
C(-T{10}, -T{3}, -T{1}), //
C(-T{10}, -T{4}, -T{2}), //
C(-T{10}, -T{5}, T{0}), //
C(-T{10}, -T{6}, -T{4}), //
C(-T{10}, -T{5}, T{0}), //
C(-T{10}, -T{8}, -T{2}), //
C(-T{10}, -T{9}, -T{1}), //
C(-T{10}, -T{10}, T{0}), //
});
}
// Float values
if constexpr (IsFloatingPoint<T>) {
ConcatInto( //
r, std::vector<Case>{
C(T{10.5}, T{1}, T{0.5}), //
C(T{10.5}, T{2}, T{0.5}), //
C(T{10.5}, T{3}, T{1.5}), //
C(T{10.5}, T{4}, T{2.5}), //
C(T{10.5}, T{5}, T{0.5}), //
C(T{10.5}, T{6}, T{4.5}), //
C(T{10.5}, T{5}, T{0.5}), //
C(T{10.5}, T{8}, T{2.5}), //
C(T{10.5}, T{9}, T{1.5}), //
C(T{10.5}, T{10}, T{0.5}), //
// lhs negative, rhs positive
C(-T{10.5}, T{1}, -T{0.5}), //
C(-T{10.5}, T{2}, -T{0.5}), //
C(-T{10.5}, T{3}, -T{1.5}), //
C(-T{10.5}, T{4}, -T{2.5}), //
C(-T{10.5}, T{5}, -T{0.5}), //
C(-T{10.5}, T{6}, -T{4.5}), //
C(-T{10.5}, T{5}, -T{0.5}), //
C(-T{10.5}, T{8}, -T{2.5}), //
C(-T{10.5}, T{9}, -T{1.5}), //
C(-T{10.5}, T{10}, -T{0.5}), //
// lhs positive, rhs negative
C(T{10.5}, -T{1}, T{0.5}), //
C(T{10.5}, -T{2}, T{0.5}), //
C(T{10.5}, -T{3}, T{1.5}), //
C(T{10.5}, -T{4}, T{2.5}), //
C(T{10.5}, -T{5}, T{0.5}), //
C(T{10.5}, -T{6}, T{4.5}), //
C(T{10.5}, -T{5}, T{0.5}), //
C(T{10.5}, -T{8}, T{2.5}), //
C(T{10.5}, -T{9}, T{1.5}), //
C(T{10.5}, -T{10}, T{0.5}), //
// lhs negative, rhs negative
C(-T{10.5}, -T{1}, -T{0.5}), //
C(-T{10.5}, -T{2}, -T{0.5}), //
C(-T{10.5}, -T{3}, -T{1.5}), //
C(-T{10.5}, -T{4}, -T{2.5}), //
C(-T{10.5}, -T{5}, -T{0.5}), //
C(-T{10.5}, -T{6}, -T{4.5}), //
C(-T{10.5}, -T{5}, -T{0.5}), //
C(-T{10.5}, -T{8}, -T{2.5}), //
C(-T{10.5}, -T{9}, -T{1.5}), //
C(-T{10.5}, -T{10}, -T{0.5}), //
});
}
return r;
}
INSTANTIATE_TEST_SUITE_P(Mod,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kModulo),
testing::ValuesIn(Concat( //
OpModCases<AInt>(),
OpModCases<i32>(),
OpModCases<u32>(),
OpModCases<AFloat>(),
OpModCases<f32>(),
OpModCases<f16>()))));
template <typename T, bool equals>
std::vector<Case> OpEqualCases() {
return {

View File

@ -11322,48 +11322,48 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[30],
/* template types */ &kTemplateTypes[22],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[718],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpModulo,
},
{
/* [251] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[30],
/* template types */ &kTemplateTypes[22],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[720],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpModulo,
},
{
/* [252] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[30],
/* template types */ &kTemplateTypes[22],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[722],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpModulo,
},
{
/* [253] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[30],
/* template types */ &kTemplateTypes[22],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[724],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpModulo,
},
{
/* [254] */
@ -14929,10 +14929,10 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
},
{
/* [4] */
/* op %<T : fiu32_f16>(T, T) -> T */
/* op %<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op %<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op %<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op %<T : fia_fiu32_f16>(T, T) -> T */
/* op %<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op %<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op %<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* num overloads */ 4,
/* overloads */ &kOverloads[250],
},