From ebc5bba6718e5c1193fd116dac537b69dd1ec35b Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Fri, 16 Sep 2022 17:16:38 +0000 Subject: [PATCH] tint: const eval of binary XOR Bug: tint:1581 Change-Id: I5605426f0c4b9447ce770092de4ab2f639d0218d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/102580 Reviewed-by: Ben Clayton Commit-Queue: Antonio Maiorano Kokoro: Kokoro --- src/tint/intrinsics.def | 4 +- src/tint/resolver/const_eval.cc | 17 +++++++++ src/tint/resolver/const_eval.h | 9 +++++ src/tint/resolver/const_eval_test.cc | 37 +++++++++++++++++++ src/tint/resolver/intrinsic_table.inl | 12 +++--- .../scalar-scalar/i32.wgsl.expected.dxc.hlsl | 2 +- .../scalar-scalar/i32.wgsl.expected.fxc.hlsl | 2 +- .../scalar-scalar/i32.wgsl.expected.glsl | 2 +- .../scalar-scalar/u32.wgsl.expected.dxc.hlsl | 2 +- .../scalar-scalar/u32.wgsl.expected.fxc.hlsl | 2 +- .../scalar-scalar/u32.wgsl.expected.glsl | 2 +- 11 files changed, 77 insertions(+), 14 deletions(-) diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index e32a0f4769..23e0dc86bc 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -920,8 +920,8 @@ op % (vec, vec) -> vec op % (vec, T) -> vec op % (T, vec) -> vec -op ^ (T, T) -> T -op ^ (vec, vec) -> vec +@const op ^ (T, T) -> T +@const op ^ (vec, vec) -> vec @const op & (bool, bool) -> bool @const op & (vec, vec) -> vec diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 8e46f9a410..7e33244cdb 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -1445,6 +1445,23 @@ ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty, return r; } +ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty, + utils::VectorRef args, + const Source&) { + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { + auto create = [&](auto i, auto j) -> const Constant* { + return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j}); + }; + return Dispatch_ia_iu32(create, c0, c1); + }; + + auto r = TransformElements(builder, ty, transform, args[0], args[1]); + if (builder.Diagnostics().contains_errors()) { + return utils::Failure; + } + return r; +} + ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty, utils::VectorRef args, const Source&) { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 6b57556c55..04e22826ee 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -356,6 +356,15 @@ class ConstEval { utils::VectorRef args, const Source& source); + /// Bitwise xor operator '^' + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location of the conversion + /// @return the result value, or null if the value cannot be calculated + ConstantResult OpXor(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + //////////////////////////////////////////////////////////////////////////// // Builtins //////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index f06fa32b66..3b1d7eeac1 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -3675,6 +3675,43 @@ TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) { }); } +template +std::vector XorCases() { + using B = BitValues; + return { + C(T{0b1010}, T{0b1111}, T{0b0101}), + C(T{0b1010}, T{0b0000}, T{0b1010}), + C(T{0b1010}, T{0b0011}, T{0b1001}), + C(T{0b1010}, T{0b1100}, T{0b0110}), + C(T{0b1010}, T{0b0101}, T{0b1111}), + C(B::All, B::All, T{0}), + C(B::LeftMost, B::LeftMost, T{0}), + C(B::RightMost, B::RightMost, T{0}), + C(B::All, T{0}, B::All), + C(T{0}, B::All, B::All), + C(B::LeftMost, B::AllButLeftMost, B::All), + C(B::AllButLeftMost, B::LeftMost, B::All), + C(B::RightMost, B::AllButRightMost, B::All), + C(B::AllButRightMost, B::RightMost, B::All), + C(Vec(B::All, B::LeftMost, B::RightMost), // + Vec(B::All, B::All, B::All), // + Vec(T{0}, B::AllButLeftMost, B::AllButRightMost)), // + C(Vec(B::All, B::LeftMost, B::RightMost), // + Vec(T{0}, T{0}, T{0}), // + Vec(B::All, B::LeftMost, B::RightMost)), // + C(Vec(B::LeftMost, B::RightMost), // + Vec(B::AllButLeftMost, B::AllButRightMost), // + Vec(B::All, B::All)), + }; +} +INSTANTIATE_TEST_SUITE_P(Xor, + ResolverConstEvalBinaryOpTest, + testing::Combine( // + testing::Values(ast::BinaryOp::kXor), + testing::ValuesIn(Concat(XorCases(), // + XorCases(), // + XorCases())))); + // Tests for errors on overflow/underflow of binary operations with abstract numbers struct OverflowCase { ast::BinaryOp op; diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl index 559466db0d..f74cb252d0 100644 --- a/src/tint/resolver/intrinsic_table.inl +++ b/src/tint/resolver/intrinsic_table.inl @@ -13122,24 +13122,24 @@ constexpr OverloadInfo kOverloads[] = { /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 0, - /* template types */ &kTemplateTypes[14], + /* template types */ &kTemplateTypes[10], /* template numbers */ &kTemplateNumbers[10], /* parameters */ &kParameters[689], /* return matcher indices */ &kMatcherIndices[1], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpXor, }, { /* [413] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[14], + /* template types */ &kTemplateTypes[10], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[687], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpXor, }, { /* [414] */ @@ -14703,8 +14703,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = { }, { /* [5] */ - /* op ^(T, T) -> T */ - /* op ^(vec, vec) -> vec */ + /* op ^(T, T) -> T */ + /* op ^(vec, vec) -> vec */ /* num overloads */ 2, /* overloads */ &kOverloads[412], }, diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl index f3a33ad1eb..6202d8b57b 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const int r = (1 ^ 2); + const int r = 3; return; } diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl index f3a33ad1eb..6202d8b57b 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const int r = (1 ^ 2); + const int r = 3; return; } diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl index 235697cbc7..aa5e335278 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/i32.wgsl.expected.glsl @@ -1,7 +1,7 @@ #version 310 es void f() { - int r = (1 ^ 2); + int r = 3; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl index 5e4b9090cb..8a5e6557bb 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const uint r = (1u ^ 2u); + const uint r = 3u; return; } diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl index 5e4b9090cb..8a5e6557bb 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const uint r = (1u ^ 2u); + const uint r = 3u; return; } diff --git a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl index c6cce08ee1..936c36da95 100644 --- a/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl +++ b/test/tint/expressions/binary/bit-xor/scalar-scalar/u32.wgsl.expected.glsl @@ -1,7 +1,7 @@ #version 310 es void f() { - uint r = (1u ^ 2u); + uint r = 3u; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;