tint: const eval of binary bitwise AND and OR

Bug: tint:1581
Change-Id: Id6a7a1c8e45ee91bede8014dca03a59035b29678
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/102060
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-09-13 18:13:01 +00:00 committed by Dawn LUCI CQ
parent 5b3707a2d7
commit e53b6f9502
28 changed files with 272 additions and 50 deletions

View File

@ -921,15 +921,15 @@ op % <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op ^ <T: iu32>(T, T) -> T
op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op & (bool, bool) -> bool
op & <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
op & <T: iu32>(T, T) -> T
op & <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const op & (bool, bool) -> bool
@const op & <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
@const op & <T: ia_iu32>(T, T) -> T
@const op & <T: ia_iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op | (bool, bool) -> bool
op | <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
op | <T: iu32>(T, T) -> T
op | <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const op | (bool, bool) -> bool
@const op | <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
@const op | <T: ia_iu32>(T, T) -> T
@const op | <T: ia_iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op && (bool, bool) -> bool
op || (bool, bool) -> bool

View File

@ -64,6 +64,18 @@ auto Dispatch_ia_iu32(F&& f, CONSTANTS&&... cs) {
[&](const sem::U32*) { return f(cs->template As<u32>()...); });
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS>
auto Dispatch_ia_iu32_bool(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
[&](const sem::AbstractInt*) { return f(cs->template As<AInt>()...); },
[&](const sem::I32*) { return f(cs->template As<i32>()...); },
[&](const sem::U32*) { return f(cs->template As<u32>()...); },
[&](const sem::Bool*) { return f(cs->template As<bool>()...); });
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS>
@ -1384,6 +1396,54 @@ ConstEval::ConstantResult ConstEval::OpGreaterThanEqual(const sem::Type* ty,
return r;
}
ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* {
using T = decltype(i);
T result;
if constexpr (std::is_same_v<T, bool>) {
result = i && j;
} else { // integral
result = i & j;
}
return CreateElement(builder, sem::Type::DeepestElementOf(ty), result);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
auto r = TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* {
using T = decltype(i);
T result;
if constexpr (std::is_same_v<T, bool>) {
result = i || j;
} else { // integral
result = i | j;
}
return CreateElement(builder, sem::Type::DeepestElementOf(ty), result);
};
return Dispatch_ia_iu32_bool(create, c0, c1);
};
auto r = TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@ -338,6 +338,24 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Bitwise and operator '&'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpAnd(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Bitwise or operator '|'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpOr(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////

View File

@ -3135,6 +3135,17 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
return true;
}
template <typename T>
struct BitValues {
using UT = UnwrapNumber<T>;
static constexpr size_t NumBits = sizeof(UT) * 8;
static inline const T All = T{~T{0}};
static inline const T LeftMost = T{T{1} << (NumBits - 1u)};
static inline const T AllButLeftMost = T{~LeftMost};
static inline const T RightMost = T{1};
static inline const T AllButRightMost = T{~RightMost};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Unary op
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -3722,6 +3733,136 @@ INSTANTIATE_TEST_SUITE_P(LessThanEqual,
OpGreaterThanCases<f32, false>(),
OpGreaterThanCases<f16, false>()))));
static std::vector<Case> OpAndBoolCases() {
return {
C(true, true, true),
C(true, false, false),
C(false, true, false),
C(false, false, false),
C(Vec(true, true), Vec(true, false), Vec(true, false)),
C(Vec(true, true), Vec(false, true), Vec(false, true)),
C(Vec(true, false), Vec(true, false), Vec(true, false)),
C(Vec(false, true), Vec(true, false), Vec(false, false)),
C(Vec(false, false), Vec(true, false), Vec(false, false)),
};
}
template <typename T>
std::vector<Case> OpAndIntCases() {
using B = BitValues<T>;
return {
C(T{0b1010}, T{0b1111}, T{0b1010}),
C(T{0b1010}, T{0b0000}, T{0b0000}),
C(T{0b1010}, T{0b0011}, T{0b0010}),
C(T{0b1010}, T{0b1100}, T{0b1000}),
C(T{0b1010}, T{0b0101}, T{0b0000}),
C(B::All, B::All, B::All),
C(B::LeftMost, B::LeftMost, B::LeftMost),
C(B::RightMost, B::RightMost, B::RightMost),
C(B::All, T{0}, T{0}),
C(T{0}, B::All, T{0}),
C(B::LeftMost, B::AllButLeftMost, T{0}),
C(B::AllButLeftMost, B::LeftMost, T{0}),
C(B::RightMost, B::AllButRightMost, T{0}),
C(B::AllButRightMost, B::RightMost, T{0}),
C(Vec(B::All, B::LeftMost, B::RightMost), //
Vec(B::All, B::All, B::All), //
Vec(B::All, B::LeftMost, B::RightMost)), //
C(Vec(B::All, B::LeftMost, B::RightMost), //
Vec(T{0}, T{0}, T{0}), //
Vec(T{0}, T{0}, T{0})), //
C(Vec(B::LeftMost, B::RightMost), //
Vec(B::AllButLeftMost, B::AllButRightMost), //
Vec(T{0}, T{0})),
};
}
INSTANTIATE_TEST_SUITE_P(And,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kAnd),
testing::ValuesIn( //
Concat(OpAndBoolCases(), //
OpAndIntCases<AInt>(),
OpAndIntCases<i32>(),
OpAndIntCases<u32>()))));
static std::vector<Case> OpOrBoolCases() {
return {
C(true, true, true),
C(true, false, true),
C(false, true, true),
C(false, false, false),
C(Vec(true, true), Vec(true, false), Vec(true, true)),
C(Vec(true, true), Vec(false, true), Vec(true, true)),
C(Vec(true, false), Vec(true, false), Vec(true, false)),
C(Vec(false, true), Vec(true, false), Vec(true, true)),
C(Vec(false, false), Vec(true, false), Vec(true, false)),
};
}
template <typename T>
std::vector<Case> OpOrIntCases() {
using B = BitValues<T>;
return {
C(T{0b1010}, T{0b1111}, T{0b1111}),
C(T{0b1010}, T{0b0000}, T{0b1010}),
C(T{0b1010}, T{0b0011}, T{0b1011}),
C(T{0b1010}, T{0b1100}, T{0b1110}),
C(T{0b1010}, T{0b0101}, T{0b1111}),
C(B::All, B::All, B::All),
C(B::LeftMost, B::LeftMost, B::LeftMost),
C(B::RightMost, B::RightMost, B::RightMost),
C(B::All, T{0}, B::All),
C(T{0}, B::All, B::All),
C(B::LeftMost, B::AllButLeftMost, B::All),
C(B::AllButLeftMost, B::LeftMost, B::All),
C(B::RightMost, B::AllButRightMost, B::All),
C(B::AllButRightMost, B::RightMost, B::All),
C(Vec(B::All, B::LeftMost, B::RightMost), //
Vec(B::All, B::All, B::All), //
Vec(B::All, B::All, B::All)), //
C(Vec(B::All, B::LeftMost, B::RightMost), //
Vec(T{0}, T{0}, T{0}), //
Vec(B::All, B::LeftMost, B::RightMost)), //
C(Vec(B::LeftMost, B::RightMost), //
Vec(B::AllButLeftMost, B::AllButRightMost), //
Vec(B::All, B::All)),
};
}
INSTANTIATE_TEST_SUITE_P(Or,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kOr),
testing::ValuesIn(Concat(OpOrBoolCases(),
OpOrIntCases<AInt>(),
OpOrIntCases<i32>(),
OpOrIntCases<u32>()))));
TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) {
// const C = !((vec2(true, true) & vec2(true, false)) | vec2(false, true));
auto v1 = Vec(true, true).Expr(*this);
auto v2 = Vec(true, false).Expr(*this);
auto v3 = Vec(false, true).Expr(*this);
auto expr = Not(Or(And(v1, v2), v3));
GlobalConst("C", expr);
auto expected_expr = Vec(false, false).Expr(*this);
GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
const sem::Constant* value = sem->ConstantValue();
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
auto* expected_sem = Sem().Get(expected_expr);
const sem::Constant* expected_value = expected_sem->ConstantValue();
ASSERT_NE(expected_value, nullptr);
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
EXPECT_EQ(a->As<bool>(), b->As<bool>());
return !HasFailure();
});
}
// Tests for errors on overflow/underflow of binary operations with abstract numbers
struct OverflowCase {
ast::BinaryOp op;

View File

@ -11185,7 +11185,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[687],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpAnd,
},
{
/* [252] */
@ -11197,31 +11197,31 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[685],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpAnd,
},
{
/* [253] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[14],
/* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[683],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpAnd,
},
{
/* [254] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[14],
/* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[681],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpAnd,
},
{
/* [255] */
@ -11233,7 +11233,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[679],
/* return matcher indices */ &kMatcherIndices[16],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpOr,
},
{
/* [256] */
@ -11245,31 +11245,31 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[677],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpOr,
},
{
/* [257] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[14],
/* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[675],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpOr,
},
{
/* [258] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[14],
/* template types */ &kTemplateTypes[10],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[673],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpOr,
},
{
/* [259] */
@ -14671,8 +14671,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
/* [6] */
/* op &(bool, bool) -> bool */
/* op &<N : num>(vec<N, bool>, vec<N, bool>) -> vec<N, bool> */
/* op &<T : iu32>(T, T) -> T */
/* op &<T : iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op &<T : ia_iu32>(T, T) -> T */
/* op &<T : ia_iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 4,
/* overloads */ &kOverloads[251],
},
@ -14680,8 +14680,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
/* [7] */
/* op |(bool, bool) -> bool */
/* op |<N : num>(vec<N, bool>, vec<N, bool>) -> vec<N, bool> */
/* op |<T : iu32>(T, T) -> T */
/* op |<T : iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op |<T : ia_iu32>(T, T) -> T */
/* op |<T : ia_iu32, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 4,
/* overloads */ &kOverloads[255],
},

View File

@ -765,7 +765,7 @@ auto Val(T v) {
/// Creates a `Value<vec<N, T>>` from N scalar `args`
template <typename... T>
auto Vec(T&&... args) {
auto Vec(T... args) {
constexpr size_t N = sizeof...(args);
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
utils::Vector v{args...};

View File

@ -1,5 +1,6 @@
@compute
@workgroup_size(1)
fn main() {
var v = select(true & true, true, false);
let a = true;
var v = select(a & true, true, false);
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void main() {
bool v = (false ? true : (true & true));
bool v = (false ? true : true);
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void main() {
bool v = (false ? true : (true & true));
bool v = (false ? true : true);
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void tint_symbol() {
bool v = (false ? true : bool(uint(true) & uint(true)));
bool v = (false ? true : true);
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -2,7 +2,8 @@
using namespace metal;
kernel void tint_symbol() {
bool v = select(bool(true & true), true, false);
bool const a = true;
bool v = select(bool(a & true), true, false);
return;
}

View File

@ -12,14 +12,14 @@
%void = OpTypeVoid
%1 = OpTypeFunction %void
%bool = OpTypeBool
%7 = OpConstantNull %bool
%true = OpConstantTrue %bool
%8 = OpConstantNull %bool
%_ptr_Function_bool = OpTypePointer Function %bool
%main = OpFunction %void None %1
%4 = OpLabel
%v = OpVariable %_ptr_Function_bool Function %7
%v = OpVariable %_ptr_Function_bool Function %8
%9 = OpLogicalAnd %bool %true %true
%5 = OpSelect %bool %7 %true %9
OpStore %v %5
%7 = OpSelect %bool %8 %true %9
OpStore %v %7
OpReturn
OpFunctionEnd

View File

@ -1,4 +1,5 @@
@compute @workgroup_size(1)
fn main() {
var v = select((true & true), true, false);
let a = true;
var v = select((a & true), true, false);
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const bool r = (true & false);
const bool r = false;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const bool r = (true & false);
const bool r = false;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
bool r = bool(uint(true) & uint(false));
bool r = false;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 & 2);
const int r = 0;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 & 2);
const int r = 0;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
int r = (1 & 2);
int r = 0;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u & 2u);
const uint r = 0u;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u & 2u);
const uint r = 0u;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
uint r = (1u & 2u);
uint r = 0u;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 | 2);
const int r = 3;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 | 2);
const int r = 3;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
int r = (1 | 2);
int r = 3;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u | 2u);
const uint r = 3u;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u | 2u);
const uint r = 3u;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
uint r = (1u | 2u);
uint r = 3u;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;