tint: Implement const eval of binary multiply

Bug: tint:1581
Change-Id: I70ff40ed4d8faf0a665824fef936ffbafb3f0948
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99362
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2022-09-01 14:57:39 +00:00
committed by Dawn LUCI CQ
parent ae6f76fe3a
commit c20c5dfb4a
35 changed files with 1243 additions and 482 deletions

View File

@@ -898,15 +898,15 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op * <T: fiu32_f16>(T, T) -> T
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op * <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
op * <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op * <T: f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
op * <T: f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
op * <T: f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
op * <T: f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
op * <T: f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
@const("Multiply") op * <T: fia_fiu32_f16>(T, T) -> T
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
@const("MultiplyMatVec") op * <T: fa_f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
@const("MultiplyVecMat") op * <T: fa_f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
@const("MultiplyMatMat") op * <T: fa_f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
op / <T: fiu32_f16>(T, T) -> T
op / <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>

View File

@@ -22,6 +22,7 @@
#include <optional>
#include <ostream>
#include "src/tint/traits.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
@@ -33,6 +34,14 @@ struct Number;
} // namespace tint
namespace tint::detail {
/// Base template for IsNumber
template <typename T>
struct IsNumber : std::false_type {};
/// Specialization for IsNumber
template <typename T>
struct IsNumber<Number<T>> : std::true_type {};
/// An empty structure used as a unique template type for Number when
/// specializing for the f16 type.
struct NumberKindF16 {};
@@ -68,6 +77,10 @@ constexpr bool IsInteger = std::is_integral_v<T>;
template <typename T>
constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
/// Evaluates to true iff T is a Number
template <typename T>
constexpr bool IsNumber = detail::IsNumber<T>::value;
/// Resolves to the underlying type for a Number.
template <typename T>
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
@@ -236,6 +249,26 @@ using f32 = Number<float>;
/// However since C++ don't have native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>;
/// @returns the friendly name of Number type T
template <typename T, typename = traits::EnableIf<IsNumber<T>>>
const char* FriendlyName() {
if constexpr (std::is_same_v<T, AInt>) {
return "abstract-int";
} else if constexpr (std::is_same_v<T, AFloat>) {
return "abstract-float";
} else if constexpr (std::is_same_v<T, i32>) {
return "i32";
} else if constexpr (std::is_same_v<T, u32>) {
return "u32";
} else if constexpr (std::is_same_v<T, f32>) {
return "f32";
} else if constexpr (std::is_same_v<T, f16>) {
return "f16";
} else {
static_assert(!sizeof(T), "Unhandled type");
}
}
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
@@ -437,6 +470,15 @@ inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
return AInt(result);
}
/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
auto result = a.value * b.value;
if (!std::isfinite(result)) {
return {};
}
return AFloat{result};
}
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635

View File

@@ -38,6 +38,7 @@
#include "src/tint/sem/vector.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
@@ -508,11 +509,202 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder,
auto* ty = n0 > n1 ? c0->Type() : c1->Type();
return CreateComposite(builder, ty, std::move(els));
}
} // namespace
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
using T = UnwrapNumber<NumberT>;
auto add_values = [](T lhs, T rhs) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// Ensure no UB for signed overflow
using UT = std::make_unsigned_t<T>;
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
} else {
return lhs + rhs;
}
};
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
// Check for over/underflow for abstract values
if (auto r = CheckedAdd(a, b)) {
result = r->value;
} else {
AddError("'" + std::to_string(add_values(a.value, b.value)) +
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
*current_source);
return utils::Failure;
}
} else {
result = add_values(a.value, b.value);
}
return result;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
using T = UnwrapNumber<NumberT>;
auto mul_values = [](T lhs, T rhs) { //
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by multiplying as unsigned
using UT = std::make_unsigned_t<T>;
return static_cast<T>(static_cast<UT>(lhs) * static_cast<UT>(rhs));
} else {
return lhs * rhs;
}
};
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
// Check for over/underflow for abstract values
if (auto r = CheckedMul(a, b)) {
result = r->value;
} else {
AddError("'" + std::to_string(mul_values(a.value, b.value)) +
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
*current_source);
return utils::Failure;
}
} else {
result = mul_values(a.value, b.value);
}
return result;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
auto r1 = Mul(a1, b1);
if (!r1) {
return utils::Failure;
}
auto r2 = Mul(a2, b2);
if (!r2) {
return utils::Failure;
}
auto r = Add(r1.Get(), r2.Get());
if (!r) {
return utils::Failure;
}
return r;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot3(NumberT a1,
NumberT a2,
NumberT a3,
NumberT b1,
NumberT b2,
NumberT b3) {
auto r1 = Mul(a1, b1);
if (!r1) {
return utils::Failure;
}
auto r2 = Mul(a2, b2);
if (!r2) {
return utils::Failure;
}
auto r3 = Mul(a3, b3);
if (!r3) {
return utils::Failure;
}
auto r = Add(r1.Get(), r2.Get());
if (!r) {
return utils::Failure;
}
r = Add(r.Get(), r3.Get());
if (!r) {
return utils::Failure;
}
return r;
}
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
NumberT a2,
NumberT a3,
NumberT a4,
NumberT b1,
NumberT b2,
NumberT b3,
NumberT b4) {
auto r1 = Mul(a1, b1);
if (!r1) {
return utils::Failure;
}
auto r2 = Mul(a2, b2);
if (!r2) {
return utils::Failure;
}
auto r3 = Mul(a3, b3);
if (!r3) {
return utils::Failure;
}
auto r4 = Mul(a4, b4);
if (!r4) {
return utils::Failure;
}
auto r = Add(r1.Get(), r2.Get());
if (!r) {
return utils::Failure;
}
r = Add(r.Get(), r3.Get());
if (!r) {
return utils::Failure;
}
r = Add(r.Get(), r4.Get());
if (!r) {
return utils::Failure;
}
return r;
}
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
if (auto r = Add(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::MulFunc(const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
if (auto r = Mul(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
if (auto r = Dot2(a1, a2, b1, b2)) {
return CreateElement(builder, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto b1, auto b2,
auto b3) -> utils::Result<const Constant*> {
if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) {
return CreateElement(builder, elem_ty, r.Get());
}
return utils::Failure;
};
}
auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
auto b4) -> utils::Result<const Constant*> {
if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
return CreateElement(builder, elem_ty, r.Get());
}
return utils::Failure;
};
}
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) {
return Switch(
@@ -756,42 +948,15 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
return TransformElements(builder, transform, args[0]);
}
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* {
using NumberT = decltype(i);
using T = UnwrapNumber<NumberT>;
auto add_values = [](T lhs, T rhs) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// Ensure no UB for signed overflow
using UT = std::make_unsigned_t<T>;
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
} else {
return lhs + rhs;
}
};
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
// Check for over/underflow for abstract values
if (auto r = CheckedAdd(i, j)) {
result = r->value;
} else {
AddError("'" + std::to_string(add_values(i.value, j.value)) +
"' cannot be represented as '" +
ty->FriendlyName(builder.Symbols()) + "'",
source);
return nullptr;
}
} else {
result = add_values(i.value, j.value);
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
return r.Get();
}
return nullptr;
};
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
@@ -846,6 +1011,192 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
return r;
}
ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) {
return r.Get();
}
return nullptr;
};
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
auto* vec_ty = args[1]->Type()->As<sem::Vector>();
auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
utils::Result<const Constant*> result;
switch (mat_ty->columns()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
m->Index(0)->Index(row), //
m->Index(1)->Index(row), //
v->Index(0), //
v->Index(1));
break;
case 3:
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
m->Index(0)->Index(row), //
m->Index(1)->Index(row), //
m->Index(2)->Index(row), //
v->Index(0), //
v->Index(1), v->Index(2));
break;
case 4:
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
m->Index(0)->Index(row), //
m->Index(1)->Index(row), //
m->Index(2)->Index(row), //
m->Index(3)->Index(row), //
v->Index(0), //
v->Index(1), //
v->Index(2), //
v->Index(3));
break;
}
return result;
};
utils::Vector<const sem::Constant*, 4> result;
for (size_t i = 0; i < mat_ty->rows(); ++i) {
auto r = dot(args[0], i, args[1]); // matrix row i * vector
if (!r) {
return utils::Failure;
}
result.Push(r.Get());
}
return CreateComposite(builder, ty, result);
}
ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
utils::Result<const Constant*> result;
switch (mat_ty->rows()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
m->Index(col)->Index(0), //
m->Index(col)->Index(1), //
v->Index(0), //
v->Index(1));
break;
case 3:
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
m->Index(col)->Index(0), //
m->Index(col)->Index(1), //
m->Index(col)->Index(2),
v->Index(0), //
v->Index(1), //
v->Index(2));
break;
case 4:
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
m->Index(col)->Index(0), //
m->Index(col)->Index(1), //
m->Index(col)->Index(2), //
m->Index(col)->Index(3), //
v->Index(0), //
v->Index(1), //
v->Index(2), //
v->Index(3));
}
return result;
};
utils::Vector<const sem::Constant*, 4> result;
for (size_t i = 0; i < mat_ty->columns(); ++i) {
auto r = dot(args[0], args[1], i); // vector * matrix col i
if (!r) {
return utils::Failure;
}
result.Push(r.Get());
}
return CreateComposite(builder, ty, result);
}
ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat1 = args[0];
auto* mat2 = args[1];
auto* mat1_ty = mat1->Type()->As<sem::Matrix>();
auto* mat2_ty = mat2->Type()->As<sem::Matrix>();
auto* elem_ty = mat1_ty->type();
auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) {
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
utils::Result<const Constant*> result;
switch (mat1_ty->columns()) {
case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
m1e(row, 0), //
m1e(row, 1), //
m2e(0, col), //
m2e(1, col));
break;
case 3:
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
m1e(row, 0), //
m1e(row, 1), //
m1e(row, 2), //
m2e(0, col), //
m2e(1, col), //
m2e(2, col));
break;
case 4:
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
m1e(row, 0), //
m1e(row, 1), //
m1e(row, 2), //
m1e(row, 3), //
m2e(0, col), //
m2e(1, col), //
m2e(2, col), //
m2e(3, col));
break;
}
return result;
};
utils::Vector<const sem::Constant*, 4> result_mat;
for (size_t c = 0; c < mat2_ty->columns(); ++c) {
utils::Vector<const sem::Constant*, 4> col_vec;
for (size_t r = 0; r < mat1_ty->rows(); ++r) {
auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c
if (!v) {
return utils::Failure;
}
col_vec.Push(v.Get()); // mat1 row r * mat2 col c
}
// Add column vector to matrix
auto* col_vec_ty = ty->As<sem::Matrix>()->ColumnType();
result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
}
return CreateComposite(builder, ty, result_mat);
}
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -230,6 +230,42 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for the same type on the LHS and RHS
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiply(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for matCxR<T> * vecC<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyMatVec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for vecR<T> * matCxR<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyVecMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for matKxR<T> * matCxK<T>
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyMatMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
@@ -259,7 +295,97 @@ class ConstEval {
/// Adds the given warning message to the diagnostics
void AddWarning(const std::string& msg, const Source& source) const;
/// Adds two Number<T>s
/// @param a the lhs number
/// @param b the rhs number
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Add(NumberT a, NumberT b);
/// Multiplies two Number<T>s
/// @param a the lhs number
/// @param b the rhs number
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Mul(NumberT a, NumberT b);
/// Returns the dot product of (a1,a2) with (b1,b2)
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2);
/// Returns the dot product of (a1,a2,a3) with (b1,b2,b3)
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param a3 component 3 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @param b3 component 3 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot3(NumberT a1,
NumberT a2,
NumberT a3,
NumberT b1,
NumberT b2,
NumberT b3);
/// Returns the dot product of (a1,b1,c1,d1) with (a2,b2,c2,d2)
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param a3 component 3 of lhs vector
/// @param a4 component 4 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @param b3 component 3 of rhs vector
/// @param b4 component 4 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot4(NumberT a1,
NumberT a2,
NumberT a3,
NumberT a4,
NumberT b1,
NumberT b2,
NumberT b3,
NumberT b4);
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto AddFunc(const sem::Type* elem_ty);
/// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto MulFunc(const sem::Type* elem_ty);
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot2Func(const sem::Type* elem_ty);
/// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot3Func(const sem::Type* elem_ty);
/// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot4Func(const sem::Type* elem_ty);
ProgramBuilder& builder;
const Source* current_source = nullptr;
};
} // namespace tint::resolver

View File

@@ -15,6 +15,7 @@
#include <cmath>
#include <type_traits>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/builtin_type.h"
@@ -24,6 +25,8 @@
#include "src/tint/sem/test_helper.h"
#include "src/tint/utils/transform.h"
using ::testing::HasSubstr;
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
@@ -74,6 +77,19 @@ auto Abs(const Number<T>& v) {
}
}
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
template <typename T>
constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by multiplying as unsigned
using UT = std::make_unsigned_t<T>;
return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
} else {
return static_cast<Number<T>>(v1 * v2);
}
}
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
// Concats any number of std::vectors
template <typename Vec, typename... Vecs>
[[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) {
@@ -3228,29 +3244,26 @@ Case C(T lhs, U rhs, V expected, bool overflow = false) {
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
auto print_value = [&](auto&& value) {
std::visit(
[&](auto&& v) {
using ValueType = std::decay_t<decltype(v)>;
o << ValueType::DataType::Name() << "(";
for (auto& a : v.args.values) {
o << std::get<typename ValueType::ElementType>(a);
if (&a != &v.args.values.Back()) {
o << ", ";
}
static std::ostream& operator<<(std::ostream& o, const Types& types) {
std::visit(
[&](auto&& v) {
using ValueType = std::decay_t<decltype(v)>;
o << ValueType::DataType::Name() << "(";
for (auto& a : v.args.values) {
o << std::get<typename ValueType::ElementType>(a);
if (&a != &v.args.values.Back()) {
o << ", ";
}
o << ")";
},
value);
};
o << "lhs: ";
print_value(c.lhs);
o << ", rhs: ";
print_value(c.rhs);
o << ", expected: ";
print_value(c.expected);
o << ", overflow: " << c.overflow;
}
o << ")";
},
types);
return o;
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected
<< ", overflow: " << c.overflow;
return o;
}
@@ -3281,7 +3294,6 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
TEST_P(ResolverConstEvalBinaryOpTest, Test) {
Enable(ast::Extension::kF16);
auto op = std::get<0>(GetParam());
auto& c = std::get<1>(GetParam());
@@ -3300,10 +3312,8 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
GlobalConst("C", expr);
auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
@@ -3413,6 +3423,215 @@ INSTANTIATE_TEST_SUITE_P(Sub,
OpSubFloatCases<f32>(),
OpSubFloatCases<f16>()))));
template <typename T>
std::vector<Case> OpMulScalarCases() {
return {
C(T{0}, T{0}, T{0}),
C(T{1}, T{2}, T{2}),
C(T{2}, T{3}, T{6}),
C(Negate(T{2}), T{3}, Negate(T{6})),
C(T::Highest(), T{1}, T::Highest()),
C(T::Lowest(), T{1}, T::Lowest()),
C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true),
C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true),
};
}
template <typename T>
std::vector<Case> OpMulVecCases() {
return {
// s * vec3 = vec3
C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})),
// vec3 * s = vec3
C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
// vec3 * vec3 = vec3
C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
};
}
template <typename T>
std::vector<Case> OpMulMatCases() {
return {
// s * mat3x2 = mat3x2
C(Val(T{2.25}),
Mat({T{1.0}, T{4.0}}, //
{T{2.0}, T{5.0}}, //
{T{3.0}, T{6.0}}),
Mat({T{2.25}, T{9.0}}, //
{T{4.5}, T{11.25}}, //
{T{6.75}, T{13.5}})),
// mat3x2 * s = mat3x2
C(Mat({T{1.0}, T{4.0}}, //
{T{2.0}, T{5.0}}, //
{T{3.0}, T{6.0}}),
Val(T{2.25}),
Mat({T{2.25}, T{9.0}}, //
{T{4.5}, T{11.25}}, //
{T{6.75}, T{13.5}})),
// vec3 * mat2x3 = vec2
C(Vec(T{1.25}, T{2.25}, T{3.25}), //
Mat({T{1.0}, T{2.0}, T{3.0}}, //
{T{4.0}, T{5.0}, T{6.0}}), //
Vec(T{15.5}, T{35.75})),
// mat2x3 * vec2 = vec3
C(Mat({T{1.0}, T{2.0}, T{3.0}}, //
{T{4.0}, T{5.0}, T{6.0}}), //
Vec(T{1.25}, T{2.25}), //
Vec(T{10.25}, T{13.75}, T{17.25})),
// mat3x2 * mat2x3 = mat2x2
C(Mat({T{1.0}, T{2.0}}, //
{T{3.0}, T{4.0}}, //
{T{5.0}, T{6.0}}), //
Mat({T{1.25}, T{2.25}, T{3.25}}, //
{T{4.25}, T{5.25}, T{6.25}}), //
Mat({T{24.25}, T{31.0}}, //
{T{51.25}, T{67.0}})), //
};
}
INSTANTIATE_TEST_SUITE_P(Mul,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kMultiply),
testing::ValuesIn(Concat( //
OpMulScalarCases<AInt>(),
OpMulScalarCases<i32>(),
OpMulScalarCases<u32>(),
OpMulScalarCases<AFloat>(),
OpMulScalarCases<f32>(),
OpMulScalarCases<f16>(),
OpMulVecCases<AInt>(),
OpMulVecCases<i32>(),
OpMulVecCases<u32>(),
OpMulVecCases<AFloat>(),
OpMulVecCases<f32>(),
OpMulVecCases<f16>(),
OpMulMatCases<AFloat>(),
OpMulMatCases<f32>(),
OpMulMatCases<f16>()))));
// Tests for errors on overflow/underflow of binary operations with abstract numbers
struct OverflowCase {
ast::BinaryOp op;
Types lhs;
Types rhs;
std::string overflowed_result;
};
static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) {
o << ast::FriendlyName(c.op) << ", lhs: " << c.lhs << ", rhs: " << c.rhs;
return o;
}
using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCase>;
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
Enable(ast::Extension::kF16);
auto& c = GetParam();
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
GlobalConst("C", expr);
ASSERT_FALSE(r()->Resolve());
std::string type_name = std::visit(
[&](auto&& value) {
using ValueType = std::decay_t<decltype(value)>;
return tint::FriendlyName<typename ValueType::ElementType>();
},
c.lhs);
EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + c.overflowed_result +
"' cannot be represented as '" + type_name + "'"));
}
INSTANTIATE_TEST_SUITE_P(
Test,
ResolverConstEvalBinaryOpTest_Overflow,
testing::Values( //
// scalar-scalar add
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"},
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"},
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"},
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"},
// scalar-scalar subtract
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a),
"9223372036854775807"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a),
"-9223372036854775808"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()),
"inf"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()),
"-inf"},
// scalar-scalar multiply
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"},
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"},
// scalar-vector multiply
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"},
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"},
// vector-matrix multiply
// Overflow from first multiplication of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(AFloat::Highest(), 1.0_a), //
Mat({2.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from second multiplication of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(1.0_a, AFloat::Highest()), //
Mat({1.0_a, 2.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from addition of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(AFloat::Highest(), AFloat::Highest()), //
Mat({1.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// matrix-matrix multiply
// Overflow from first multiplication of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({AFloat::Highest(), 1.0_a}, //
{1.0_a, 1.0_a}), //
Mat({2.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from second multiplication of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({1.0_a, AFloat::Highest()}, //
{1.0_a, 1.0_a}), //
Mat({1.0_a, 1.0_a}, //
{2.0_a, 1.0_a}), //
"inf"},
// Overflow from addition of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({AFloat::Highest(), 1.0_a}, //
{AFloat::Highest(), 1.0_a}), //
Mat({1.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"}
));
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
EXPECT_FALSE(r()->Resolve());

View File

@@ -9572,108 +9572,108 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[727],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [118] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[725],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [119] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[723],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [120] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[721],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [121] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[719],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [122] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[717],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [123] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[715],
/* return matcher indices */ &kMatcherIndices[69],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyMatVec,
},
{
/* [124] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[713],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyVecMat,
},
{
/* [125] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 3,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[0],
/* parameters */ &kParameters[711],
/* return matcher indices */ &kMatcherIndices[22],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyMatMat,
},
{
/* [126] */
@@ -14630,15 +14630,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
},
{
/* [2] */
/* op *<T : fiu32_f16>(T, T) -> T */
/* op *<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op *<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op *<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op *<T : f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
/* op *<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
/* op *<T : f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
/* op *<T : f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
/* op *<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* op *<T : fia_fiu32_f16>(T, T) -> T */
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op *<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op *<T : fa_f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
/* op *<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
/* op *<T : fa_f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
/* op *<T : fa_f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
/* op *<T : fa_f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* num overloads */ 9,
/* overloads */ &kOverloads[117],
},

View File

@@ -641,15 +641,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
9 candidate operators:
operator * (T, T) -> T where: T is f32, i32, u32 or f16
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
operator * (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}
@@ -673,15 +673,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
9 candidate operators:
operator *= (T, T) -> T where: T is f32, i32, u32 or f16
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
operator *= (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}

View File

@@ -17,6 +17,7 @@
#include <ostream>
#include <variant>
#include "src/tint/debug.h"
namespace tint::utils {
@@ -36,6 +37,9 @@ struct [[nodiscard]] Result {
static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
"Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
/// Default constructor initializes to invalid state
Result() : value(std::monostate{}) {}
/// Constructor
/// @param success the success result
Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
@@ -47,27 +51,43 @@ struct [[nodiscard]] Result {
: value{failure} {}
/// @returns true if the result was a success
operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
operator bool() const {
Validate();
return std::holds_alternative<SUCCESS_TYPE>(value);
}
/// @returns true if the result was a failure
bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
bool operator!() const {
Validate();
return std::holds_alternative<FAILURE_TYPE>(value);
}
/// @returns the success value
/// @warning attempting to call this when the Result holds an failure will result in UB.
const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
const SUCCESS_TYPE* operator->() const {
Validate();
return &(Get());
}
/// @returns the success value
/// @warning attempting to call this when the Result holds an failure value will result in UB.
const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
const SUCCESS_TYPE& Get() const {
Validate();
return std::get<SUCCESS_TYPE>(value);
}
/// @returns the failure value
/// @warning attempting to call this when the Result holds a success value will result in UB.
const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
const FAILURE_TYPE& Failure() const {
Validate();
return std::get<FAILURE_TYPE>(value);
}
/// Equality operator
/// @param val the value to compare this Result to
/// @returns true if this result holds a success value equal to `value`
bool operator==(SUCCESS_TYPE val) const {
Validate();
if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
return *v == val;
}
@@ -78,14 +98,18 @@ struct [[nodiscard]] Result {
/// @param val the value to compare this Result to
/// @returns true if this result holds a failure value equal to `value`
bool operator==(FAILURE_TYPE val) const {
Validate();
if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
return *v == val;
}
return false;
}
private:
void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative<std::monostate>(value)); }
/// The result. Either a success of failure value.
std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
std::variant<std::monostate, SUCCESS_TYPE, FAILURE_TYPE> value;
};
/// Writes the result to the ostream.

View File

@@ -151,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
auto* lhs = Expr("a");
auto* rhs = Expr(1_f);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -162,13 +163,14 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
EXPECT_EQ(out.str(), "(a * 1.0f)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
Enable(ast::Extension::kF16);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
auto* lhs = Expr("a");
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -179,12 +181,13 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
EXPECT_EQ(out.str(), "(a * 1.0hf)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_f);
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -194,14 +197,15 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
EXPECT_EQ(out.str(), "(1.0f * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
Enable(ast::Extension::kF16);
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_h);
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@@ -211,7 +215,7 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
EXPECT_EQ(out.str(), "(1.0hf * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {

View File

@@ -184,7 +184,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
@@ -201,7 +201,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
@@ -216,7 +216,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
@@ -233,7 +233,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {