tint: Add Checked[Add|Mul|Madd]()

Test-for-overflow utilities for AInt.

Bug: tint:1504
Change-Id: I974ef829c72aaa4c2012550855227f71d4a370a0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91700
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2022-05-31 13:14:29 +00:00 committed by Dawn LUCI CQ
parent fa5cd029d1
commit 61537d3f57
3 changed files with 322 additions and 37 deletions

View File

@ -19,7 +19,10 @@
#include <functional>
#include <limits>
#include <ostream>
// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
#include <optional> // NOLINT(build/include_order))
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
// Forward declaration
@ -184,33 +187,6 @@ std::enable_if_t<IsNumeric<A>, bool> operator!=(A a, Number<B> b) {
return !(a == b);
}
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
kExceedsNegativeLimit, // The value was too big (-'ve) to fit in the target type
};
/// Writes the conversion failure message to the ostream.
/// @param out the std::ostream to write to
/// @param failure the ConversionFailure
/// @return the std::ostream so calls can be chained
std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
/// Converts a number from one type to another, checking that the value fits in the target type.
/// @returns the resulting value of the conversion, or a failure reason.
template <typename TO, typename FROM>
utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
using T = decltype(UnwrapNumber<TO>() + num.value);
const auto value = static_cast<T>(num.value);
if (value > static_cast<T>(TO::kHighest)) {
return ConversionFailure::kExceedsPositiveLimit;
}
if (value < static_cast<T>(TO::kLowest)) {
return ConversionFailure::kExceedsNegativeLimit;
}
return TO(value); // Success
}
/// The partial specification of Number for f16 type, storing the f16 value as float,
/// and enforcing proper explicit casting.
template <>
@ -282,6 +258,114 @@ using f32 = Number<float>;
/// However since C++ don't have native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>;
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
kExceedsNegativeLimit, // The value was too big (-'ve) to fit in the target type
};
/// Writes the conversion failure message to the ostream.
/// @param out the std::ostream to write to
/// @param failure the ConversionFailure
/// @return the std::ostream so calls can be chained
std::ostream& operator<<(std::ostream& out, ConversionFailure failure);
/// Converts a number from one type to another, checking that the value fits in the target type.
/// @returns the resulting value of the conversion, or a failure reason.
template <typename TO, typename FROM>
utils::Result<TO, ConversionFailure> CheckedConvert(Number<FROM> num) {
using T = decltype(UnwrapNumber<TO>() + num.value);
const auto value = static_cast<T>(num.value);
if (value > static_cast<T>(TO::kHighest)) {
return ConversionFailure::kExceedsPositiveLimit;
}
if (value < static_cast<T>(TO::kLowest)) {
return ConversionFailure::kExceedsNegativeLimit;
}
return TO(value); // Success
}
/// Define 'TINT_HAS_OVERFLOW_BUILTINS' if the compiler provide overflow checking builtins.
/// If the compiler does not support these builtins, then these are emulated with algorithms
/// described in:
/// https://wiki.sei.cmu.edu/confluence/display/c/INT32-C.+Ensure+that+operations+on+signed+integers+do+not+result+in+overflow
#if defined(__GNUC__) && __GNUC__ >= 5
#define TINT_HAS_OVERFLOW_BUILTINS
#elif defined(__clang__)
#if __has_builtin(__builtin_add_overflow) && __has_builtin(__builtin_mul_overflow)
#define TINT_HAS_OVERFLOW_BUILTINS
#endif
#endif
/// @returns a + b, or an empty optional if the resulting value overflowed the AInt
inline std::optional<AInt> CheckedAdd(AInt a, AInt b) {
int64_t result;
#ifdef TINT_HAS_OVERFLOW_BUILTINS
if (__builtin_add_overflow(a.value, b.value, &result)) {
return {};
}
#else // TINT_HAS_OVERFLOW_BUILTINS
if (a.value >= 0) {
if (AInt::kHighest - a.value < b.value) {
return {};
}
} else {
if (b.value < AInt::kLowest - a.value) {
return {};
}
}
result = a.value + b.value;
#endif // TINT_HAS_OVERFLOW_BUILTINS
return AInt(result);
}
/// @returns a * b, or an empty optional if the resulting value overflowed the AInt
inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
int64_t result;
#ifdef TINT_HAS_OVERFLOW_BUILTINS
if (__builtin_mul_overflow(a.value, b.value, &result)) {
return {};
}
#else // TINT_HAS_OVERFLOW_BUILTINS
if (a > 0) {
if (b > 0) {
if (a > (AInt::kHighest / b)) {
return {};
}
} else {
if (b < (AInt::kLowest / a)) {
return {};
}
}
} else {
if (b > 0) {
if (a < (AInt::kLowest / b)) {
return {};
}
} else {
if ((a != 0) && (b < (AInt::kHighest / a))) {
return {};
}
}
}
result = a.value * b.value;
#endif // TINT_HAS_OVERFLOW_BUILTINS
return AInt(result);
}
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
TINT_BEGIN_DISABLE_WARNING(MAYBE_UNINITIALIZED);
if (auto mul = CheckedMul(a, b)) {
return CheckedAdd(mul.value(), c);
}
return {};
TINT_END_DISABLE_WARNING(MAYBE_UNINITIALIZED);
}
} // namespace tint
namespace tint::number_suffixes {

View File

@ -13,6 +13,8 @@
// limitations under the License.
#include <cmath>
#include <tuple>
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/utils/compiler_macros.h"
@ -141,6 +143,165 @@ TEST(NumberTest, QuantizeF16) {
EXPECT_TRUE(std::isnan(f16(nan)));
}
using BinaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt>;
#undef OVERFLOW // corecrt_math.h :(
#define OVERFLOW \
{}
using CheckedAddTest = testing::TestWithParam<BinaryCheckedCase>;
TEST_P(CheckedAddTest, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
EXPECT_EQ(CheckedAdd(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
EXPECT_EQ(CheckedAdd(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedAddTest,
CheckedAddTest,
testing::ValuesIn(std::vector<BinaryCheckedCase>{
{AInt(0), AInt(0), AInt(0)},
{AInt(1), AInt(1), AInt(0)},
{AInt(2), AInt(1), AInt(1)},
{AInt(0), AInt(-1), AInt(1)},
{AInt(3), AInt(2), AInt(1)},
{AInt(-1), AInt(-2), AInt(1)},
{AInt(0x300), AInt(0x100), AInt(0x200)},
{AInt(0x100), AInt(-0x100), AInt(0x200)},
{AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest - 1)},
{AInt(AInt::kLowest), AInt(-1), AInt(AInt::kLowest + 1)},
{AInt(AInt::kHighest), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
{AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
{AInt(AInt::kLowest), AInt(AInt::kLowest), AInt(0)},
{OVERFLOW, AInt(1), AInt(AInt::kHighest)},
{OVERFLOW, AInt(-1), AInt(AInt::kLowest)},
{OVERFLOW, AInt(2), AInt(AInt::kHighest)},
{OVERFLOW, AInt(-2), AInt(AInt::kLowest)},
{OVERFLOW, AInt(10000), AInt(AInt::kHighest)},
{OVERFLOW, AInt(-10000), AInt(AInt::kLowest)},
{OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
{OVERFLOW, AInt(AInt::kLowest), AInt(AInt::kLowest)},
////////////////////////////////////////////////////////////////////////
}));
using CheckedMulTest = testing::TestWithParam<BinaryCheckedCase>;
TEST_P(CheckedMulTest, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
EXPECT_EQ(CheckedMul(a, b), expect) << std::hex << "0x" << a << " * 0x" << b;
EXPECT_EQ(CheckedMul(b, a), expect) << std::hex << "0x" << a << " * 0x" << b;
}
INSTANTIATE_TEST_SUITE_P(
CheckedMulTest,
CheckedMulTest,
testing::ValuesIn(std::vector<BinaryCheckedCase>{
{AInt(0), AInt(0), AInt(0)},
{AInt(0), AInt(1), AInt(0)},
{AInt(1), AInt(1), AInt(1)},
{AInt(-1), AInt(-1), AInt(1)},
{AInt(2), AInt(2), AInt(1)},
{AInt(-2), AInt(-2), AInt(1)},
{AInt(0x20000), AInt(0x100), AInt(0x200)},
{AInt(-0x20000), AInt(-0x100), AInt(0x200)},
{AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll)},
{AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll)},
{AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll)},
{AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll)},
{AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll)},
{AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2)},
{AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2)},
{AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2)},
{AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2)},
{AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4)},
{AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4)},
{AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4)},
{AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4)},
{AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8)},
{AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8)},
{AInt(0), AInt(AInt::kHighest), AInt(0)},
{AInt(0), AInt(AInt::kLowest), AInt(0)},
{OVERFLOW, AInt(0x1000000000000000ll), AInt(8)},
{OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8)},
{OVERFLOW, AInt(0x800000000000000ll), AInt(0x10)},
{OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll)},
{OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest)},
{OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest)},
////////////////////////////////////////////////////////////////////////
}));
using TernaryCheckedCase = std::tuple<std::optional<AInt>, AInt, AInt, AInt>;
using CheckedMaddTest = testing::TestWithParam<TernaryCheckedCase>;
TEST_P(CheckedMaddTest, Test) {
auto expect = std::get<0>(GetParam());
auto a = std::get<1>(GetParam());
auto b = std::get<2>(GetParam());
auto c = std::get<3>(GetParam());
EXPECT_EQ(CheckedMadd(a, b, c), expect)
<< std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
EXPECT_EQ(CheckedMadd(b, a, c), expect)
<< std::hex << "0x" << a << " * 0x" << b << " + 0x" << c;
}
INSTANTIATE_TEST_SUITE_P(
CheckedMaddTest,
CheckedMaddTest,
testing::ValuesIn(std::vector<TernaryCheckedCase>{
{AInt(0), AInt(0), AInt(0), AInt(0)},
{AInt(0), AInt(1), AInt(0), AInt(0)},
{AInt(1), AInt(1), AInt(1), AInt(0)},
{AInt(2), AInt(1), AInt(1), AInt(1)},
{AInt(0), AInt(1), AInt(-1), AInt(1)},
{AInt(-1), AInt(1), AInt(-2), AInt(1)},
{AInt(-1), AInt(-1), AInt(1), AInt(0)},
{AInt(2), AInt(2), AInt(1), AInt(0)},
{AInt(-2), AInt(-2), AInt(1), AInt(0)},
{AInt(0), AInt(AInt::kHighest), AInt(0), AInt(0)},
{AInt(0), AInt(AInt::kLowest), AInt(0), AInt(0)},
{AInt(3), AInt(1), AInt(2), AInt(1)},
{AInt(0x300), AInt(1), AInt(0x100), AInt(0x200)},
{AInt(0x100), AInt(1), AInt(-0x100), AInt(0x200)},
{AInt(0x20000), AInt(0x100), AInt(0x200), AInt(0)},
{AInt(-0x20000), AInt(-0x100), AInt(0x200), AInt(0)},
{AInt(0x4000000000000000ll), AInt(0x80000000ll), AInt(0x80000000ll), AInt(0)},
{AInt(0x4000000000000000ll), AInt(-0x80000000ll), AInt(-0x80000000ll), AInt(0)},
{AInt(0x1000000000000000ll), AInt(0x40000000ll), AInt(0x40000000ll), AInt(0)},
{AInt(-0x1000000000000000ll), AInt(-0x40000000ll), AInt(0x40000000ll), AInt(0)},
{AInt(0x100000000000000ll), AInt(0x1000000), AInt(0x100000000ll), AInt(0)},
{AInt(0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(2), AInt(0)},
{AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2), AInt(0)},
{AInt(-0x2000000000000000ll), AInt(-0x1000000000000000ll), AInt(2), AInt(0)},
{AInt(-0x2000000000000000ll), AInt(0x1000000000000000ll), AInt(-2), AInt(0)},
{AInt(0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(4), AInt(0)},
{AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4), AInt(0)},
{AInt(-0x4000000000000000ll), AInt(-0x1000000000000000ll), AInt(4), AInt(0)},
{AInt(-0x4000000000000000ll), AInt(0x1000000000000000ll), AInt(-4), AInt(0)},
{AInt(-0x8000000000000000ll), AInt(0x1000000000000000ll), AInt(-8), AInt(0)},
{AInt(-0x8000000000000000ll), AInt(-0x1000000000000000ll), AInt(8), AInt(0)},
{AInt(AInt::kHighest), AInt(1), AInt(1), AInt(AInt::kHighest - 1)},
{AInt(AInt::kLowest), AInt(1), AInt(-1), AInt(AInt::kLowest + 1)},
{AInt(AInt::kHighest), AInt(1), AInt(0x7fffffff00000000ll), AInt(0x00000000ffffffffll)},
{AInt(AInt::kHighest), AInt(1), AInt(AInt::kHighest), AInt(0)},
{AInt(AInt::kLowest), AInt(1), AInt(AInt::kLowest), AInt(0)},
{OVERFLOW, AInt(0x1000000000000000ll), AInt(8), AInt(0)},
{OVERFLOW, AInt(-0x1000000000000000ll), AInt(-8), AInt(0)},
{OVERFLOW, AInt(0x800000000000000ll), AInt(0x10), AInt(0)},
{OVERFLOW, AInt(0x80000000ll), AInt(0x100000000ll), AInt(0)},
{OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kHighest), AInt(0)},
{OVERFLOW, AInt(AInt::kHighest), AInt(AInt::kLowest), AInt(0)},
{OVERFLOW, AInt(1), AInt(1), AInt(AInt::kHighest)},
{OVERFLOW, AInt(1), AInt(-1), AInt(AInt::kLowest)},
{OVERFLOW, AInt(1), AInt(2), AInt(AInt::kHighest)},
{OVERFLOW, AInt(1), AInt(-2), AInt(AInt::kLowest)},
{OVERFLOW, AInt(1), AInt(10000), AInt(AInt::kHighest)},
{OVERFLOW, AInt(1), AInt(-10000), AInt(AInt::kLowest)},
{OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(AInt::kHighest)},
{OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(AInt::kLowest)},
{OVERFLOW, AInt(1), AInt(AInt::kHighest), AInt(1)},
{OVERFLOW, AInt(1), AInt(AInt::kLowest), AInt(-1)},
}));
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
} // namespace

View File

@ -20,23 +20,63 @@
#define TINT_REQUIRE_SEMICOLON static_assert(true)
#if defined(_MSC_VER)
#define TINT_WARNING_UNREACHABLE_CODE 4702
#define TINT_WARNING_CONSTANT_OVERFLOW 4756
////////////////////////////////////////////////////////////////////////////////
// MSVC
////////////////////////////////////////////////////////////////////////////////
#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW __pragma(warning(disable : 4756))
#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op */
#define TINT_DISABLE_WARNING_UNREACHABLE_CODE __pragma(warning(disable : 4702))
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
__pragma(warning(push)) \
__pragma(warning(disable:TINT_CONCAT(TINT_WARNING_, name))) \
#define TINT_BEGIN_DISABLE_WARNING(name) \
__pragma(warning(push)) \
TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
TINT_REQUIRE_SEMICOLON
#define TINT_END_DISABLE_WARNING(name) \
__pragma(warning(pop)) \
#define TINT_END_DISABLE_WARNING(name) \
__pragma(warning(pop)) \
TINT_REQUIRE_SEMICOLON
// clang-format on
#elif defined(__clang__)
////////////////////////////////////////////////////////////////////////////////
// Clang
////////////////////////////////////////////////////////////////////////////////
#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW /* currently no-op */
#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED /* currently no-op */
#define TINT_DISABLE_WARNING_UNREACHABLE_CODE /* currently no-op */
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
_Pragma("clang diagnostic push") \
TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
TINT_REQUIRE_SEMICOLON
#define TINT_END_DISABLE_WARNING(name) \
_Pragma("clang diagnostic pop") \
TINT_REQUIRE_SEMICOLON
// clang-format on
#elif defined(__GNUC__)
////////////////////////////////////////////////////////////////////////////////
// GCC
////////////////////////////////////////////////////////////////////////////////
#define TINT_DISABLE_WARNING_CONSTANT_OVERFLOW /* currently no-op */
#define TINT_DISABLE_WARNING_MAYBE_UNINITIALIZED \
_Pragma("GCC diagnostic ignored \"-Wmaybe-uninitialized\"")
#define TINT_DISABLE_WARNING_UNREACHABLE_CODE /* currently no-op */
// clang-format off
#define TINT_BEGIN_DISABLE_WARNING(name) \
_Pragma("GCC diagnostic push") \
TINT_CONCAT(TINT_DISABLE_WARNING_, name) \
TINT_REQUIRE_SEMICOLON
#define TINT_END_DISABLE_WARNING(name) \
_Pragma("GCC diagnostic pop") \
TINT_REQUIRE_SEMICOLON
// clang-format on
#else
// clang-format off
////////////////////////////////////////////////////////////////////////////////
// Other
////////////////////////////////////////////////////////////////////////////////
#define TINT_BEGIN_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
#define TINT_END_DISABLE_WARNING(name) TINT_REQUIRE_SEMICOLON
// clang-format on
#endif // defined(_MSC_VER)
#endif
#endif // SRC_TINT_UTILS_COMPILER_MACROS_H_