tint: Implement const eval of binary minus

Bug: tint:1581
Change-Id: I90ce59e89a5b4b9e94de1181ca9d85e9040be3e5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99421
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-08-17 17:38:23 +00:00 committed by Dawn LUCI CQ
parent eb0af9def7
commit b79238d7ec
16 changed files with 133 additions and 42 deletions

View File

@ -880,8 +880,8 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
@const op ~ <T: ia_iu32>(T) -> T @const op ~ <T: ia_iu32>(T) -> T
@const op ~ <T: ia_iu32, N: num> (vec<N, T>) -> vec<N, T> @const op ~ <T: ia_iu32, N: num> (vec<N, T>) -> vec<N, T>
@const op - <T: fia_fi32_f16>(T) -> T @const("UnaryMinus") op - <T: fia_fi32_f16>(T) -> T
@const op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T> @const("UnaryMinus") op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T>
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Binary Operators // // Binary Operators //
@ -892,11 +892,11 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
@const op + <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T> @const op + <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op + <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> @const op + <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op - <T: fiu32_f16>(T, T) -> T @const op - <T: fia_fiu32_f16>(T, T) -> T
op - <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T> @const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op - <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T> @const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
op - <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T> @const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op - <T: f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> @const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op * <T: fiu32_f16>(T, T) -> T op * <T: fiu32_f16>(T, T) -> T
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T> op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>

View File

@ -731,7 +731,7 @@ ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
return TransformElements(builder, transform, args[0]); return TransformElements(builder, transform, args[0]);
} }
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*, ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
@ -801,6 +801,51 @@ ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
return r; return r;
} }
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* {
using NumberT = decltype(i);
using T = UnwrapNumber<NumberT>;
auto subtract_values = [](T lhs, T rhs) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// Ensure no UB for signed underflow
using UT = std::make_unsigned_t<T>;
return static_cast<T>(static_cast<UT>(lhs) - static_cast<UT>(rhs));
} else {
return lhs - rhs;
}
};
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
// Check for over/underflow for abstract values
if (auto r = CheckedSub(i, j)) {
result = r->value;
} else {
AddError("'" + std::to_string(subtract_values(i.value, j.value)) +
"' cannot be represented as '" +
ty->FriendlyName(builder.Symbols()) + "'",
source);
return nullptr;
}
} else {
result = subtract_values(i.value, j.value);
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*, ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {

View File

@ -199,12 +199,12 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Minus operator '-' /// Unary minus operator '-'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMinus(const sem::Type* ty, ConstantResult OpUnaryMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
@ -221,6 +221,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Minus operator '-'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Builtins // Builtins
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////

View File

@ -3236,6 +3236,43 @@ INSTANTIATE_TEST_SUITE_P(Add,
OpAddFloatCases<f32>(), OpAddFloatCases<f32>(),
OpAddFloatCases<f16>())))); OpAddFloatCases<f16>()))));
template <typename T>
std::vector<Case> OpSubIntCases() {
static_assert(IsInteger<UnwrapNumber<T>>);
return {
C(T{0}, T{0}, T{0}),
C(T{3}, T{2}, T{1}),
C(T{T::Lowest() + 1}, T{1}, T::Lowest()),
C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()),
C(Negate(T{1}), T::Highest(), T::Lowest()),
C(T::Lowest(), T{1}, T::Highest(), true),
C(T::Highest(), Negate(T{1}), T::Lowest(), true),
};
}
template <typename T>
std::vector<Case> OpSubFloatCases() {
static_assert(IsFloatingPoint<UnwrapNumber<T>>);
return {
C(T{0}, T{0}, T{0}),
C(T{3}, T{2}, T{1}),
C(T::Highest(), T{1}, T{T::Highest() - 1}),
C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}),
C(T{0}, T::Highest(), T::Lowest()),
C(T::Highest(), Negate(T::Highest()), T::Inf(), true),
C(T::Lowest(), T::Highest(), -T::Inf(), true),
};
}
INSTANTIATE_TEST_SUITE_P(Sub,
ResolverConstEvalBinaryOpTest,
testing::Combine(testing::Values(ast::BinaryOp::kSubtract),
testing::ValuesIn(Concat( //
OpSubIntCases<AInt>(),
OpSubIntCases<i32>(),
OpSubIntCases<u32>(),
OpSubFloatCases<AFloat>(),
OpSubFloatCases<f32>(),
OpSubFloatCases<f16>()))));
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());

View File

@ -10952,60 +10952,60 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2, /* num parameters */ 2,
/* num template types */ 1, /* num template types */ 1,
/* num template numbers */ 0, /* num template numbers */ 0,
/* template types */ &kTemplateTypes[15], /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10], /* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[737], /* parameters */ &kParameters[737],
/* return matcher indices */ &kMatcherIndices[1], /* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr, /* const eval */ &ConstEval::OpMinus,
}, },
{ {
/* [233] */ /* [233] */
/* num parameters */ 2, /* num parameters */ 2,
/* num template types */ 1, /* num template types */ 1,
/* num template numbers */ 1, /* num template numbers */ 1,
/* template types */ &kTemplateTypes[15], /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6], /* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[735], /* parameters */ &kParameters[735],
/* return matcher indices */ &kMatcherIndices[30], /* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr, /* const eval */ &ConstEval::OpMinus,
}, },
{ {
/* [234] */ /* [234] */
/* num parameters */ 2, /* num parameters */ 2,
/* num template types */ 1, /* num template types */ 1,
/* num template numbers */ 1, /* num template numbers */ 1,
/* template types */ &kTemplateTypes[15], /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6], /* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[733], /* parameters */ &kParameters[733],
/* return matcher indices */ &kMatcherIndices[30], /* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr, /* const eval */ &ConstEval::OpMinus,
}, },
{ {
/* [235] */ /* [235] */
/* num parameters */ 2, /* num parameters */ 2,
/* num template types */ 1, /* num template types */ 1,
/* num template numbers */ 1, /* num template numbers */ 1,
/* template types */ &kTemplateTypes[15], /* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6], /* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[731], /* parameters */ &kParameters[731],
/* return matcher indices */ &kMatcherIndices[30], /* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr, /* const eval */ &ConstEval::OpMinus,
}, },
{ {
/* [236] */ /* [236] */
/* num parameters */ 2, /* num parameters */ 2,
/* num template types */ 1, /* num template types */ 1,
/* num template numbers */ 2, /* num template numbers */ 2,
/* template types */ &kTemplateTypes[11], /* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6], /* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[729], /* parameters */ &kParameters[729],
/* return matcher indices */ &kMatcherIndices[10], /* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr, /* const eval */ &ConstEval::OpMinus,
}, },
{ {
/* [237] */ /* [237] */
@ -13237,7 +13237,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[862], /* parameters */ &kParameters[862],
/* return matcher indices */ &kMatcherIndices[1], /* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ &ConstEval::OpMinus, /* const eval */ &ConstEval::OpUnaryMinus,
}, },
{ {
/* [423] */ /* [423] */
@ -13249,7 +13249,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[863], /* parameters */ &kParameters[863],
/* return matcher indices */ &kMatcherIndices[30], /* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ &ConstEval::OpMinus, /* const eval */ &ConstEval::OpUnaryMinus,
}, },
{ {
/* [424] */ /* [424] */
@ -14620,11 +14620,11 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
}, },
{ {
/* [1] */ /* [1] */
/* op -<T : fiu32_f16>(T, T) -> T */ /* op -<T : fia_fiu32_f16>(T, T) -> T */
/* op -<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */ /* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op -<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */ /* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op -<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */ /* op -<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op -<T : f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */ /* op -<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */
/* num overloads */ 5, /* num overloads */ 5,
/* overloads */ &kOverloads[232], /* overloads */ &kOverloads[232],
}, },

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const float16_t r = (float16_t(1.0h) - float16_t(2.0h)); const float16_t r = float16_t(-1.0h);
return; return;
} }

View File

@ -2,7 +2,7 @@
#extension GL_AMD_gpu_shader_half_float : require #extension GL_AMD_gpu_shader_half_float : require
void f() { void f() {
float16_t r = (1.0hf - 2.0hf); float16_t r = -1.0hf;
} }
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const float r = (1.0f - 2.0f); const float r = -1.0f;
return; return;
} }

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const float r = (1.0f - 2.0f); const float r = -1.0f;
return; return;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
void f() { void f() {
float r = (1.0f - 2.0f); float r = -1.0f;
} }
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const int r = (1 - 2); const int r = -1;
return; return;
} }

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const int r = (1 - 2); const int r = -1;
return; return;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
void f() { void f() {
int r = (1 - 2); int r = -1;
} }
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const uint r = (1u - 2u); const uint r = 4294967295u;
return; return;
} }

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void f() { void f() {
const uint r = (1u - 2u); const uint r = 4294967295u;
return; return;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
void f() { void f() {
uint r = (1u - 2u); uint r = 4294967295u;
} }
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;