tint: const eval of modulo operator
Bug: tint:1581 Change-Id: Icf9aaab29f45a41e6c367f60bf98ccc8958a56c7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112322 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
171c542bf7
commit
babc21f931
|
@ -932,10 +932,10 @@ init mat4x4<T: fa_f32_f16>(vec4<T>, vec4<T>, vec4<T>, vec4<T>) -> mat4x4<T>
|
|||
@const op / <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
@const op / <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
|
||||
op % <T: fiu32_f16>(T, T) -> T
|
||||
op % <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op % <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op % <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const op % <T: fia_fiu32_f16>(T, T) -> T
|
||||
@const op % <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const op % <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
@const op % <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
|
||||
@const op ^ <T: ia_iu32>(T, T) -> T
|
||||
@const op ^ <T: ia_iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
|
|
@ -744,7 +744,6 @@ utils::Result<NumberT> ConstEval::Mul(const Source& source, NumberT a, NumberT b
|
|||
using T = UnwrapNumber<NumberT>;
|
||||
NumberT result;
|
||||
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedMul(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
|
@ -770,7 +769,6 @@ template <typename NumberT>
|
|||
utils::Result<NumberT> ConstEval::Div(const Source& source, NumberT a, NumberT b) {
|
||||
NumberT result;
|
||||
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedDiv(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
|
@ -799,6 +797,38 @@ utils::Result<NumberT> ConstEval::Div(const Source& source, NumberT a, NumberT b
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Mod(const Source& source, NumberT a, NumberT b) {
|
||||
NumberT result;
|
||||
if constexpr (IsAbstract<NumberT> || IsFloatingPoint<NumberT>) {
|
||||
if (auto r = CheckedMod(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError(OverflowErrorMessage(a, "%", b), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto lhs = a.value;
|
||||
auto rhs = b.value;
|
||||
if (rhs == 0) {
|
||||
// lhs % 0 is an error
|
||||
AddError(OverflowErrorMessage(a, "%", b), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
// For signed integers, lhs % -1 where lhs is the
|
||||
// most negative value is an error
|
||||
if (rhs == -1 && lhs == std::numeric_limits<T>::min()) {
|
||||
AddError(OverflowErrorMessage(a, "%", b), source);
|
||||
return utils::Failure;
|
||||
}
|
||||
}
|
||||
result = lhs % rhs;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot2(const Source& source,
|
||||
NumberT a1,
|
||||
|
@ -1111,6 +1141,15 @@ auto ConstEval::DivFunc(const Source& source, const sem::Type* elem_ty) {
|
|||
};
|
||||
}
|
||||
|
||||
auto ConstEval::ModFunc(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2) -> ImplResult {
|
||||
if (auto r = Mod(source, a1, a2)) {
|
||||
return CreateElement(builder, source, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Dot2Func(const Source& source, const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
|
||||
if (auto r = Dot2(source, a1, a2, b1, b2)) {
|
||||
|
@ -1683,6 +1722,16 @@ ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
|
|||
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::OpModulo(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1);
|
||||
};
|
||||
|
||||
return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::OpEqual(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
|
|
@ -283,6 +283,15 @@ class ConstEval {
|
|||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Modulo operator '%'
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result OpModulo(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Equality operator '=='
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
@ -1040,6 +1049,14 @@ class ConstEval {
|
|||
template <typename NumberT>
|
||||
utils::Result<NumberT> Div(const Source& source, NumberT a, NumberT b);
|
||||
|
||||
/// Returns the (signed) remainder of the division of two Number<T>s
|
||||
/// @param source the source location
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Mod(const Source& source, NumberT a, NumberT b);
|
||||
|
||||
/// Returns the dot product of (a1,a2) with (b1,b2)
|
||||
/// @param source the source location
|
||||
/// @param a1 component 1 of lhs vector
|
||||
|
@ -1216,6 +1233,13 @@ class ConstEval {
|
|||
/// @returns the callable function
|
||||
auto DivFunc(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Mod, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto ModFunc(const Source& source, const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param source the source location
|
||||
|
|
|
@ -485,6 +485,161 @@ INSTANTIATE_TEST_SUITE_P(Div,
|
|||
OpDivFloatCases<f32>(),
|
||||
OpDivFloatCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpModCases() {
|
||||
auto error_msg = [](auto a, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, "%", b);
|
||||
};
|
||||
|
||||
// Common cases for all types
|
||||
std::vector<Case> r = {
|
||||
C(T{0}, T{1}, T{0}), //
|
||||
C(T{1}, T{1}, T{0}), //
|
||||
C(T{10}, T{1}, T{0}), //
|
||||
C(T{10}, T{2}, T{0}), //
|
||||
C(T{10}, T{3}, T{1}), //
|
||||
C(T{10}, T{4}, T{2}), //
|
||||
C(T{10}, T{5}, T{0}), //
|
||||
C(T{10}, T{6}, T{4}), //
|
||||
C(T{10}, T{5}, T{0}), //
|
||||
C(T{10}, T{8}, T{2}), //
|
||||
C(T{10}, T{9}, T{1}), //
|
||||
C(T{10}, T{10}, T{0}), //
|
||||
|
||||
// Error on divide by zero
|
||||
E(T{123}, T{0}, error_msg(T{123}, T{0})),
|
||||
E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})),
|
||||
E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})),
|
||||
};
|
||||
|
||||
if constexpr (IsIntegral<T>) {
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T::Highest(), T{T::Highest() - T{1}}, T{1}),
|
||||
});
|
||||
}
|
||||
|
||||
if constexpr (IsSignedIntegral<T>) {
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T::Lowest(), T{T::Lowest() + T{1}}, -T(1)),
|
||||
|
||||
// Error on most negative integer divided by -1
|
||||
E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})),
|
||||
});
|
||||
}
|
||||
|
||||
// Negative values (both signed integrals and floating point)
|
||||
if constexpr (IsSignedIntegral<T> || IsFloatingPoint<T>) {
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(-T{1}, T{1}, T{0}), //
|
||||
|
||||
// lhs negative, rhs positive
|
||||
C(-T{10}, T{1}, T{0}), //
|
||||
C(-T{10}, T{2}, T{0}), //
|
||||
C(-T{10}, T{3}, -T{1}), //
|
||||
C(-T{10}, T{4}, -T{2}), //
|
||||
C(-T{10}, T{5}, T{0}), //
|
||||
C(-T{10}, T{6}, -T{4}), //
|
||||
C(-T{10}, T{5}, T{0}), //
|
||||
C(-T{10}, T{8}, -T{2}), //
|
||||
C(-T{10}, T{9}, -T{1}), //
|
||||
C(-T{10}, T{10}, T{0}), //
|
||||
|
||||
// lhs positive, rhs negative
|
||||
C(T{10}, -T{1}, T{0}), //
|
||||
C(T{10}, -T{2}, T{0}), //
|
||||
C(T{10}, -T{3}, T{1}), //
|
||||
C(T{10}, -T{4}, T{2}), //
|
||||
C(T{10}, -T{5}, T{0}), //
|
||||
C(T{10}, -T{6}, T{4}), //
|
||||
C(T{10}, -T{5}, T{0}), //
|
||||
C(T{10}, -T{8}, T{2}), //
|
||||
C(T{10}, -T{9}, T{1}), //
|
||||
C(T{10}, -T{10}, T{0}), //
|
||||
|
||||
// lhs negative, rhs negative
|
||||
C(-T{10}, -T{1}, T{0}), //
|
||||
C(-T{10}, -T{2}, T{0}), //
|
||||
C(-T{10}, -T{3}, -T{1}), //
|
||||
C(-T{10}, -T{4}, -T{2}), //
|
||||
C(-T{10}, -T{5}, T{0}), //
|
||||
C(-T{10}, -T{6}, -T{4}), //
|
||||
C(-T{10}, -T{5}, T{0}), //
|
||||
C(-T{10}, -T{8}, -T{2}), //
|
||||
C(-T{10}, -T{9}, -T{1}), //
|
||||
C(-T{10}, -T{10}, T{0}), //
|
||||
});
|
||||
}
|
||||
|
||||
// Float values
|
||||
if constexpr (IsFloatingPoint<T>) {
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T{10.5}, T{1}, T{0.5}), //
|
||||
C(T{10.5}, T{2}, T{0.5}), //
|
||||
C(T{10.5}, T{3}, T{1.5}), //
|
||||
C(T{10.5}, T{4}, T{2.5}), //
|
||||
C(T{10.5}, T{5}, T{0.5}), //
|
||||
C(T{10.5}, T{6}, T{4.5}), //
|
||||
C(T{10.5}, T{5}, T{0.5}), //
|
||||
C(T{10.5}, T{8}, T{2.5}), //
|
||||
C(T{10.5}, T{9}, T{1.5}), //
|
||||
C(T{10.5}, T{10}, T{0.5}), //
|
||||
|
||||
// lhs negative, rhs positive
|
||||
C(-T{10.5}, T{1}, -T{0.5}), //
|
||||
C(-T{10.5}, T{2}, -T{0.5}), //
|
||||
C(-T{10.5}, T{3}, -T{1.5}), //
|
||||
C(-T{10.5}, T{4}, -T{2.5}), //
|
||||
C(-T{10.5}, T{5}, -T{0.5}), //
|
||||
C(-T{10.5}, T{6}, -T{4.5}), //
|
||||
C(-T{10.5}, T{5}, -T{0.5}), //
|
||||
C(-T{10.5}, T{8}, -T{2.5}), //
|
||||
C(-T{10.5}, T{9}, -T{1.5}), //
|
||||
C(-T{10.5}, T{10}, -T{0.5}), //
|
||||
|
||||
// lhs positive, rhs negative
|
||||
C(T{10.5}, -T{1}, T{0.5}), //
|
||||
C(T{10.5}, -T{2}, T{0.5}), //
|
||||
C(T{10.5}, -T{3}, T{1.5}), //
|
||||
C(T{10.5}, -T{4}, T{2.5}), //
|
||||
C(T{10.5}, -T{5}, T{0.5}), //
|
||||
C(T{10.5}, -T{6}, T{4.5}), //
|
||||
C(T{10.5}, -T{5}, T{0.5}), //
|
||||
C(T{10.5}, -T{8}, T{2.5}), //
|
||||
C(T{10.5}, -T{9}, T{1.5}), //
|
||||
C(T{10.5}, -T{10}, T{0.5}), //
|
||||
|
||||
// lhs negative, rhs negative
|
||||
C(-T{10.5}, -T{1}, -T{0.5}), //
|
||||
C(-T{10.5}, -T{2}, -T{0.5}), //
|
||||
C(-T{10.5}, -T{3}, -T{1.5}), //
|
||||
C(-T{10.5}, -T{4}, -T{2.5}), //
|
||||
C(-T{10.5}, -T{5}, -T{0.5}), //
|
||||
C(-T{10.5}, -T{6}, -T{4.5}), //
|
||||
C(-T{10.5}, -T{5}, -T{0.5}), //
|
||||
C(-T{10.5}, -T{8}, -T{2.5}), //
|
||||
C(-T{10.5}, -T{9}, -T{1.5}), //
|
||||
C(-T{10.5}, -T{10}, -T{0.5}), //
|
||||
});
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Mod,
|
||||
ResolverConstEvalBinaryOpTest,
|
||||
testing::Combine( //
|
||||
testing::Values(ast::BinaryOp::kModulo),
|
||||
testing::ValuesIn(Concat( //
|
||||
OpModCases<AInt>(),
|
||||
OpModCases<i32>(),
|
||||
OpModCases<u32>(),
|
||||
OpModCases<AFloat>(),
|
||||
OpModCases<f32>(),
|
||||
OpModCases<f16>()))));
|
||||
|
||||
template <typename T, bool equals>
|
||||
std::vector<Case> OpEqualCases() {
|
||||
return {
|
||||
|
|
|
@ -11322,48 +11322,48 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[30],
|
||||
/* template types */ &kTemplateTypes[22],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[718],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpModulo,
|
||||
},
|
||||
{
|
||||
/* [251] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[30],
|
||||
/* template types */ &kTemplateTypes[22],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[720],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpModulo,
|
||||
},
|
||||
{
|
||||
/* [252] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[30],
|
||||
/* template types */ &kTemplateTypes[22],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[722],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpModulo,
|
||||
},
|
||||
{
|
||||
/* [253] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[30],
|
||||
/* template types */ &kTemplateTypes[22],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[724],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpModulo,
|
||||
},
|
||||
{
|
||||
/* [254] */
|
||||
|
@ -14929,10 +14929,10 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
|
|||
},
|
||||
{
|
||||
/* [4] */
|
||||
/* op %<T : fiu32_f16>(T, T) -> T */
|
||||
/* op %<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op %<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op %<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op %<T : fia_fiu32_f16>(T, T) -> T */
|
||||
/* op %<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op %<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op %<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 4,
|
||||
/* overloads */ &kOverloads[250],
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue