tint: Implement const eval of binary minus
Bug: tint:1581 Change-Id: I90ce59e89a5b4b9e94de1181ca9d85e9040be3e5 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99421 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
eb0af9def7
commit
b79238d7ec
|
@ -880,8 +880,8 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
|
|||
@const op ~ <T: ia_iu32>(T) -> T
|
||||
@const op ~ <T: ia_iu32, N: num> (vec<N, T>) -> vec<N, T>
|
||||
|
||||
@const op - <T: fia_fi32_f16>(T) -> T
|
||||
@const op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T>
|
||||
@const("UnaryMinus") op - <T: fia_fi32_f16>(T) -> T
|
||||
@const("UnaryMinus") op - <T: fia_fi32_f16, N: num> (vec<N, T>) -> vec<N, T>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Binary Operators //
|
||||
|
@ -892,11 +892,11 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
|
|||
@const op + <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const op + <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op - <T: fiu32_f16>(T, T) -> T
|
||||
op - <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op - <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op - <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op - <T: f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
@const op - <T: fia_fiu32_f16>(T, T) -> T
|
||||
@const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const op - <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op * <T: fiu32_f16>(T, T) -> T
|
||||
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
|
|
@ -731,7 +731,7 @@ ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type*,
|
|||
return TransformElements(builder, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type*,
|
||||
ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto transform = [&](const sem::Constant* c) {
|
||||
|
@ -801,6 +801,51 @@ ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
|
|||
return r;
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto i, auto j) -> const Constant* {
|
||||
using NumberT = decltype(i);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
|
||||
auto subtract_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// Ensure no UB for signed underflow
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) - static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs - rhs;
|
||||
}
|
||||
};
|
||||
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedSub(i, j)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(subtract_values(i.value, j.value)) +
|
||||
"' cannot be represented as '" +
|
||||
ty->FriendlyName(builder.Symbols()) + "'",
|
||||
source);
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
result = subtract_values(i.value, j.value);
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_fia_fiu32_f16(create, c0, c1);
|
||||
};
|
||||
|
||||
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
|
||||
if (builder.Diagnostics().contains_errors()) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
|
|
@ -199,12 +199,12 @@ class ConstEval {
|
|||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Minus operator '-'
|
||||
/// Unary minus operator '-'
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMinus(const sem::Type* ty,
|
||||
ConstantResult OpUnaryMinus(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
|
@ -221,6 +221,15 @@ class ConstEval {
|
|||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Minus operator '-'
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMinus(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Builtins
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -3236,6 +3236,43 @@ INSTANTIATE_TEST_SUITE_P(Add,
|
|||
OpAddFloatCases<f32>(),
|
||||
OpAddFloatCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpSubIntCases() {
|
||||
static_assert(IsInteger<UnwrapNumber<T>>);
|
||||
return {
|
||||
C(T{0}, T{0}, T{0}),
|
||||
C(T{3}, T{2}, T{1}),
|
||||
C(T{T::Lowest() + 1}, T{1}, T::Lowest()),
|
||||
C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()),
|
||||
C(Negate(T{1}), T::Highest(), T::Lowest()),
|
||||
C(T::Lowest(), T{1}, T::Highest(), true),
|
||||
C(T::Highest(), Negate(T{1}), T::Lowest(), true),
|
||||
};
|
||||
}
|
||||
template <typename T>
|
||||
std::vector<Case> OpSubFloatCases() {
|
||||
static_assert(IsFloatingPoint<UnwrapNumber<T>>);
|
||||
return {
|
||||
C(T{0}, T{0}, T{0}),
|
||||
C(T{3}, T{2}, T{1}),
|
||||
C(T::Highest(), T{1}, T{T::Highest() - 1}),
|
||||
C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}),
|
||||
C(T{0}, T::Highest(), T::Lowest()),
|
||||
C(T::Highest(), Negate(T::Highest()), T::Inf(), true),
|
||||
C(T::Lowest(), T::Highest(), -T::Inf(), true),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(Sub,
|
||||
ResolverConstEvalBinaryOpTest,
|
||||
testing::Combine(testing::Values(ast::BinaryOp::kSubtract),
|
||||
testing::ValuesIn(Concat( //
|
||||
OpSubIntCases<AInt>(),
|
||||
OpSubIntCases<i32>(),
|
||||
OpSubIntCases<u32>(),
|
||||
OpSubFloatCases<AFloat>(),
|
||||
OpSubFloatCases<f32>(),
|
||||
OpSubFloatCases<f16>()))));
|
||||
|
||||
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
|
||||
GlobalConst("c", nullptr, Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
|
|
@ -10952,60 +10952,60 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[737],
|
||||
/* return matcher indices */ &kMatcherIndices[1],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
},
|
||||
{
|
||||
/* [233] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[735],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
},
|
||||
{
|
||||
/* [234] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[733],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
},
|
||||
{
|
||||
/* [235] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[731],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
},
|
||||
{
|
||||
/* [236] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[729],
|
||||
/* return matcher indices */ &kMatcherIndices[10],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
},
|
||||
{
|
||||
/* [237] */
|
||||
|
@ -13237,7 +13237,7 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* parameters */ &kParameters[862],
|
||||
/* return matcher indices */ &kMatcherIndices[1],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
/* const eval */ &ConstEval::OpUnaryMinus,
|
||||
},
|
||||
{
|
||||
/* [423] */
|
||||
|
@ -13249,7 +13249,7 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* parameters */ &kParameters[863],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ &ConstEval::OpMinus,
|
||||
/* const eval */ &ConstEval::OpUnaryMinus,
|
||||
},
|
||||
{
|
||||
/* [424] */
|
||||
|
@ -14620,11 +14620,11 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
|
|||
},
|
||||
{
|
||||
/* [1] */
|
||||
/* op -<T : fiu32_f16>(T, T) -> T */
|
||||
/* op -<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op -<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op -<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op -<T : f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* op -<T : fia_fiu32_f16>(T, T) -> T */
|
||||
/* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op -<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op -<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op -<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* num overloads */ 5,
|
||||
/* overloads */ &kOverloads[232],
|
||||
},
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float16_t r = (float16_t(1.0h) - float16_t(2.0h));
|
||||
const float16_t r = float16_t(-1.0h);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#extension GL_AMD_gpu_shader_half_float : require
|
||||
|
||||
void f() {
|
||||
float16_t r = (1.0hf - 2.0hf);
|
||||
float16_t r = -1.0hf;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float r = (1.0f - 2.0f);
|
||||
const float r = -1.0f;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float r = (1.0f - 2.0f);
|
||||
const float r = -1.0f;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
float r = (1.0f - 2.0f);
|
||||
float r = -1.0f;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const int r = (1 - 2);
|
||||
const int r = -1;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const int r = (1 - 2);
|
||||
const int r = -1;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
int r = (1 - 2);
|
||||
int r = -1;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const uint r = (1u - 2u);
|
||||
const uint r = 4294967295u;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const uint r = (1u - 2u);
|
||||
const uint r = 4294967295u;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
uint r = (1u - 2u);
|
||||
uint r = 4294967295u;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
Loading…
Reference in New Issue