mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-10 14:08:04 +00:00
tint: Implement const eval of binary multiply
Bug: tint:1581 Change-Id: I70ff40ed4d8faf0a665824fef936ffbafb3f0948 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99362 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
ae6f76fe3a
commit
c20c5dfb4a
@@ -898,15 +898,15 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
|
||||
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op * <T: fiu32_f16>(T, T) -> T
|
||||
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op * <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op * <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op * <T: f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
|
||||
op * <T: f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
|
||||
op * <T: f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
|
||||
op * <T: f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
|
||||
op * <T: f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16>(T, T) -> T
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
|
||||
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
|
||||
@const("MultiplyMatVec") op * <T: fa_f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
|
||||
@const("MultiplyVecMat") op * <T: fa_f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
|
||||
@const("MultiplyMatMat") op * <T: fa_f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
|
||||
|
||||
op / <T: fiu32_f16>(T, T) -> T
|
||||
op / <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
|
||||
#include "src/tint/traits.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/result.h"
|
||||
|
||||
@@ -33,6 +34,14 @@ struct Number;
|
||||
} // namespace tint
|
||||
|
||||
namespace tint::detail {
|
||||
/// Base template for IsNumber
|
||||
template <typename T>
|
||||
struct IsNumber : std::false_type {};
|
||||
|
||||
/// Specialization for IsNumber
|
||||
template <typename T>
|
||||
struct IsNumber<Number<T>> : std::true_type {};
|
||||
|
||||
/// An empty structure used as a unique template type for Number when
|
||||
/// specializing for the f16 type.
|
||||
struct NumberKindF16 {};
|
||||
@@ -68,6 +77,10 @@ constexpr bool IsInteger = std::is_integral_v<T>;
|
||||
template <typename T>
|
||||
constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
|
||||
|
||||
/// Evaluates to true iff T is a Number
|
||||
template <typename T>
|
||||
constexpr bool IsNumber = detail::IsNumber<T>::value;
|
||||
|
||||
/// Resolves to the underlying type for a Number.
|
||||
template <typename T>
|
||||
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
|
||||
@@ -236,6 +249,26 @@ using f32 = Number<float>;
|
||||
/// However since C++ don't have native binary16 type, the value is stored as float.
|
||||
using f16 = Number<detail::NumberKindF16>;
|
||||
|
||||
/// @returns the friendly name of Number type T
|
||||
template <typename T, typename = traits::EnableIf<IsNumber<T>>>
|
||||
const char* FriendlyName() {
|
||||
if constexpr (std::is_same_v<T, AInt>) {
|
||||
return "abstract-int";
|
||||
} else if constexpr (std::is_same_v<T, AFloat>) {
|
||||
return "abstract-float";
|
||||
} else if constexpr (std::is_same_v<T, i32>) {
|
||||
return "i32";
|
||||
} else if constexpr (std::is_same_v<T, u32>) {
|
||||
return "u32";
|
||||
} else if constexpr (std::is_same_v<T, f32>) {
|
||||
return "f32";
|
||||
} else if constexpr (std::is_same_v<T, f16>) {
|
||||
return "f16";
|
||||
} else {
|
||||
static_assert(!sizeof(T), "Unhandled type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerator of failure reasons when converting from one number to another.
|
||||
enum class ConversionFailure {
|
||||
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
|
||||
@@ -437,6 +470,15 @@ inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
|
||||
return AInt(result);
|
||||
}
|
||||
|
||||
/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
|
||||
inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
|
||||
auto result = a.value * b.value;
|
||||
if (!std::isfinite(result)) {
|
||||
return {};
|
||||
}
|
||||
return AFloat{result};
|
||||
}
|
||||
|
||||
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
|
||||
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
|
||||
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
#include "src/tint/sem/vector.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
@@ -508,11 +509,202 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder,
|
||||
auto* ty = n0 > n1 ? c0->Type() : c1->Type();
|
||||
return CreateComposite(builder, ty, std::move(els));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto add_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// Ensure no UB for signed overflow
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedAdd(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(add_values(a.value, b.value)) +
|
||||
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
|
||||
*current_source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
result = add_values(a.value, b.value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto mul_values = [](T lhs, T rhs) { //
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// For signed integrals, avoid C++ UB by multiplying as unsigned
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) * static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedMul(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(mul_values(a.value, b.value)) +
|
||||
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
|
||||
*current_source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
result = mul_values(a.value, b.value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot3(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r3 = Mul(a3, b3);
|
||||
if (!r3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT a4,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3,
|
||||
NumberT b4) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r3 = Mul(a3, b3);
|
||||
if (!r3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r4 = Mul(a4, b4);
|
||||
if (!r4) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r4.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
|
||||
if (auto r = Add(a1, a2)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::MulFunc(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2) -> utils::Result<const Constant*> {
|
||||
if (auto r = Mul(a1, a2)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
|
||||
if (auto r = Dot2(a1, a2, b1, b2)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2, auto a3, auto b1, auto b2,
|
||||
auto b3) -> utils::Result<const Constant*> {
|
||||
if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
|
||||
auto b4) -> utils::Result<const Constant*> {
|
||||
if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
|
||||
const ast::LiteralExpression* literal) {
|
||||
return Switch(
|
||||
@@ -756,42 +948,15 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
|
||||
return TransformElements(builder, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
|
||||
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto i, auto j) -> const Constant* {
|
||||
using NumberT = decltype(i);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
|
||||
auto add_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// Ensure no UB for signed overflow
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedAdd(i, j)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(add_values(i.value, j.value)) +
|
||||
"' cannot be represented as '" +
|
||||
ty->FriendlyName(builder.Symbols()) + "'",
|
||||
source);
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
result = add_values(i.value, j.value);
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_fia_fiu32_f16(create, c0, c1);
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
|
||||
if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
|
||||
return r.Get();
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
|
||||
@@ -846,6 +1011,192 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
|
||||
if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) {
|
||||
return r.Get();
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
|
||||
if (builder.Diagnostics().contains_errors()) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
|
||||
auto* vec_ty = args[1]->Type()->As<sem::Vector>();
|
||||
auto* elem_ty = vec_ty->type();
|
||||
|
||||
auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
|
||||
utils::Result<const Constant*> result;
|
||||
switch (mat_ty->columns()) {
|
||||
case 2:
|
||||
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
|
||||
m->Index(0)->Index(row), //
|
||||
m->Index(1)->Index(row), //
|
||||
v->Index(0), //
|
||||
v->Index(1));
|
||||
break;
|
||||
case 3:
|
||||
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
|
||||
m->Index(0)->Index(row), //
|
||||
m->Index(1)->Index(row), //
|
||||
m->Index(2)->Index(row), //
|
||||
v->Index(0), //
|
||||
v->Index(1), v->Index(2));
|
||||
break;
|
||||
case 4:
|
||||
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
|
||||
m->Index(0)->Index(row), //
|
||||
m->Index(1)->Index(row), //
|
||||
m->Index(2)->Index(row), //
|
||||
m->Index(3)->Index(row), //
|
||||
v->Index(0), //
|
||||
v->Index(1), //
|
||||
v->Index(2), //
|
||||
v->Index(3));
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
utils::Vector<const sem::Constant*, 4> result;
|
||||
for (size_t i = 0; i < mat_ty->rows(); ++i) {
|
||||
auto r = dot(args[0], i, args[1]); // matrix row i * vector
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
result.Push(r.Get());
|
||||
}
|
||||
return CreateComposite(builder, ty, result);
|
||||
}
|
||||
ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto* vec_ty = args[0]->Type()->As<sem::Vector>();
|
||||
auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
|
||||
auto* elem_ty = vec_ty->type();
|
||||
|
||||
auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
|
||||
utils::Result<const Constant*> result;
|
||||
switch (mat_ty->rows()) {
|
||||
case 2:
|
||||
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
|
||||
m->Index(col)->Index(0), //
|
||||
m->Index(col)->Index(1), //
|
||||
v->Index(0), //
|
||||
v->Index(1));
|
||||
break;
|
||||
case 3:
|
||||
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
|
||||
m->Index(col)->Index(0), //
|
||||
m->Index(col)->Index(1), //
|
||||
m->Index(col)->Index(2),
|
||||
v->Index(0), //
|
||||
v->Index(1), //
|
||||
v->Index(2));
|
||||
break;
|
||||
case 4:
|
||||
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
|
||||
m->Index(col)->Index(0), //
|
||||
m->Index(col)->Index(1), //
|
||||
m->Index(col)->Index(2), //
|
||||
m->Index(col)->Index(3), //
|
||||
v->Index(0), //
|
||||
v->Index(1), //
|
||||
v->Index(2), //
|
||||
v->Index(3));
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
utils::Vector<const sem::Constant*, 4> result;
|
||||
for (size_t i = 0; i < mat_ty->columns(); ++i) {
|
||||
auto r = dot(args[0], args[1], i); // vector * matrix col i
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
result.Push(r.Get());
|
||||
}
|
||||
return CreateComposite(builder, ty, result);
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto* mat1 = args[0];
|
||||
auto* mat2 = args[1];
|
||||
auto* mat1_ty = mat1->Type()->As<sem::Matrix>();
|
||||
auto* mat2_ty = mat2->Type()->As<sem::Matrix>();
|
||||
auto* elem_ty = mat1_ty->type();
|
||||
|
||||
auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) {
|
||||
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
|
||||
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
|
||||
|
||||
utils::Result<const Constant*> result;
|
||||
switch (mat1_ty->columns()) {
|
||||
case 2:
|
||||
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
|
||||
m1e(row, 0), //
|
||||
m1e(row, 1), //
|
||||
m2e(0, col), //
|
||||
m2e(1, col));
|
||||
break;
|
||||
case 3:
|
||||
result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), //
|
||||
m1e(row, 0), //
|
||||
m1e(row, 1), //
|
||||
m1e(row, 2), //
|
||||
m2e(0, col), //
|
||||
m2e(1, col), //
|
||||
m2e(2, col));
|
||||
break;
|
||||
case 4:
|
||||
result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), //
|
||||
m1e(row, 0), //
|
||||
m1e(row, 1), //
|
||||
m1e(row, 2), //
|
||||
m1e(row, 3), //
|
||||
m2e(0, col), //
|
||||
m2e(1, col), //
|
||||
m2e(2, col), //
|
||||
m2e(3, col));
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
utils::Vector<const sem::Constant*, 4> result_mat;
|
||||
for (size_t c = 0; c < mat2_ty->columns(); ++c) {
|
||||
utils::Vector<const sem::Constant*, 4> col_vec;
|
||||
for (size_t r = 0; r < mat1_ty->rows(); ++r) {
|
||||
auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c
|
||||
if (!v) {
|
||||
return utils::Failure;
|
||||
}
|
||||
col_vec.Push(v.Get()); // mat1 row r * mat2 col c
|
||||
}
|
||||
|
||||
// Add column vector to matrix
|
||||
auto* col_vec_ty = ty->As<sem::Matrix>()->ColumnType();
|
||||
result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
|
||||
}
|
||||
return CreateComposite(builder, ty, result_mat);
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
||||
@@ -230,6 +230,42 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for the same type on the LHS and RHS
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiply(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for matCxR<T> * vecC<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyMatVec(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for vecR<T> * matCxR<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyVecMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for matKxR<T> * matCxK<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyMatMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Builtins
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
@@ -259,7 +295,97 @@ class ConstEval {
|
||||
/// Adds the given warning message to the diagnostics
|
||||
void AddWarning(const std::string& msg, const Source& source) const;
|
||||
|
||||
/// Adds two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Add(NumberT a, NumberT b);
|
||||
|
||||
/// Multiplies two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Mul(NumberT a, NumberT b);
|
||||
|
||||
/// Returns the dot product of (a1,a2) with (b1,b2)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2);
|
||||
|
||||
/// Returns the dot product of (a1,a2,a3) with (b1,b2,b3)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param a3 component 3 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @param b3 component 3 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot3(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3);
|
||||
|
||||
/// Returns the dot product of (a1,b1,c1,d1) with (a2,b2,c2,d2)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param a3 component 3 of lhs vector
|
||||
/// @param a4 component 4 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @param b3 component 3 of rhs vector
|
||||
/// @param b4 component 4 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot4(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT a4,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3,
|
||||
NumberT b4);
|
||||
|
||||
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto AddFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto MulFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot2Func(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot3Func(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot4Func(const sem::Type* elem_ty);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
const Source* current_source = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "src/tint/resolver/resolver_test_helper.h"
|
||||
#include "src/tint/sem/builtin_type.h"
|
||||
@@ -24,6 +25,8 @@
|
||||
#include "src/tint/sem/test_helper.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
@@ -74,6 +77,19 @@ auto Abs(const Number<T>& v) {
|
||||
}
|
||||
}
|
||||
|
||||
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||
template <typename T>
|
||||
constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// For signed integrals, avoid C++ UB by multiplying as unsigned
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
|
||||
} else {
|
||||
return static_cast<Number<T>>(v1 * v2);
|
||||
}
|
||||
}
|
||||
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||
|
||||
// Concats any number of std::vectors
|
||||
template <typename Vec, typename... Vecs>
|
||||
[[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) {
|
||||
@@ -3228,29 +3244,26 @@ Case C(T lhs, U rhs, V expected, bool overflow = false) {
|
||||
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
|
||||
}
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
auto print_value = [&](auto&& value) {
|
||||
std::visit(
|
||||
[&](auto&& v) {
|
||||
using ValueType = std::decay_t<decltype(v)>;
|
||||
o << ValueType::DataType::Name() << "(";
|
||||
for (auto& a : v.args.values) {
|
||||
o << std::get<typename ValueType::ElementType>(a);
|
||||
if (&a != &v.args.values.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
static std::ostream& operator<<(std::ostream& o, const Types& types) {
|
||||
std::visit(
|
||||
[&](auto&& v) {
|
||||
using ValueType = std::decay_t<decltype(v)>;
|
||||
o << ValueType::DataType::Name() << "(";
|
||||
for (auto& a : v.args.values) {
|
||||
o << std::get<typename ValueType::ElementType>(a);
|
||||
if (&a != &v.args.values.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
o << ")";
|
||||
},
|
||||
value);
|
||||
};
|
||||
o << "lhs: ";
|
||||
print_value(c.lhs);
|
||||
o << ", rhs: ";
|
||||
print_value(c.rhs);
|
||||
o << ", expected: ";
|
||||
print_value(c.expected);
|
||||
o << ", overflow: " << c.overflow;
|
||||
}
|
||||
o << ")";
|
||||
},
|
||||
types);
|
||||
return o;
|
||||
}
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected
|
||||
<< ", overflow: " << c.overflow;
|
||||
return o;
|
||||
}
|
||||
|
||||
@@ -3281,7 +3294,6 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
|
||||
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
|
||||
TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto op = std::get<0>(GetParam());
|
||||
auto& c = std::get<1>(GetParam());
|
||||
|
||||
@@ -3300,10 +3312,8 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
||||
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
|
||||
|
||||
GlobalConst("C", expr);
|
||||
|
||||
auto* expected_expr = expected.Expr(*this);
|
||||
GlobalConst("E", expected_expr);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* sem = Sem().Get(expr);
|
||||
@@ -3413,6 +3423,215 @@ INSTANTIATE_TEST_SUITE_P(Sub,
|
||||
OpSubFloatCases<f32>(),
|
||||
OpSubFloatCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulScalarCases() {
|
||||
return {
|
||||
C(T{0}, T{0}, T{0}),
|
||||
C(T{1}, T{2}, T{2}),
|
||||
C(T{2}, T{3}, T{6}),
|
||||
C(Negate(T{2}), T{3}, Negate(T{6})),
|
||||
C(T::Highest(), T{1}, T::Highest()),
|
||||
C(T::Lowest(), T{1}, T::Lowest()),
|
||||
C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true),
|
||||
C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true),
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulVecCases() {
|
||||
return {
|
||||
// s * vec3 = vec3
|
||||
C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
// vec3 * s = vec3
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
// vec3 * vec3 = vec3
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulMatCases() {
|
||||
return {
|
||||
// s * mat3x2 = mat3x2
|
||||
C(Val(T{2.25}),
|
||||
Mat({T{1.0}, T{4.0}}, //
|
||||
{T{2.0}, T{5.0}}, //
|
||||
{T{3.0}, T{6.0}}),
|
||||
Mat({T{2.25}, T{9.0}}, //
|
||||
{T{4.5}, T{11.25}}, //
|
||||
{T{6.75}, T{13.5}})),
|
||||
// mat3x2 * s = mat3x2
|
||||
C(Mat({T{1.0}, T{4.0}}, //
|
||||
{T{2.0}, T{5.0}}, //
|
||||
{T{3.0}, T{6.0}}),
|
||||
Val(T{2.25}),
|
||||
Mat({T{2.25}, T{9.0}}, //
|
||||
{T{4.5}, T{11.25}}, //
|
||||
{T{6.75}, T{13.5}})),
|
||||
// vec3 * mat2x3 = vec2
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), //
|
||||
Mat({T{1.0}, T{2.0}, T{3.0}}, //
|
||||
{T{4.0}, T{5.0}, T{6.0}}), //
|
||||
Vec(T{15.5}, T{35.75})),
|
||||
// mat2x3 * vec2 = vec3
|
||||
C(Mat({T{1.0}, T{2.0}, T{3.0}}, //
|
||||
{T{4.0}, T{5.0}, T{6.0}}), //
|
||||
Vec(T{1.25}, T{2.25}), //
|
||||
Vec(T{10.25}, T{13.75}, T{17.25})),
|
||||
// mat3x2 * mat2x3 = mat2x2
|
||||
C(Mat({T{1.0}, T{2.0}}, //
|
||||
{T{3.0}, T{4.0}}, //
|
||||
{T{5.0}, T{6.0}}), //
|
||||
Mat({T{1.25}, T{2.25}, T{3.25}}, //
|
||||
{T{4.25}, T{5.25}, T{6.25}}), //
|
||||
Mat({T{24.25}, T{31.0}}, //
|
||||
{T{51.25}, T{67.0}})), //
|
||||
};
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Mul,
|
||||
ResolverConstEvalBinaryOpTest,
|
||||
testing::Combine( //
|
||||
testing::Values(ast::BinaryOp::kMultiply),
|
||||
testing::ValuesIn(Concat( //
|
||||
OpMulScalarCases<AInt>(),
|
||||
OpMulScalarCases<i32>(),
|
||||
OpMulScalarCases<u32>(),
|
||||
OpMulScalarCases<AFloat>(),
|
||||
OpMulScalarCases<f32>(),
|
||||
OpMulScalarCases<f16>(),
|
||||
OpMulVecCases<AInt>(),
|
||||
OpMulVecCases<i32>(),
|
||||
OpMulVecCases<u32>(),
|
||||
OpMulVecCases<AFloat>(),
|
||||
OpMulVecCases<f32>(),
|
||||
OpMulVecCases<f16>(),
|
||||
OpMulMatCases<AFloat>(),
|
||||
OpMulMatCases<f32>(),
|
||||
OpMulMatCases<f16>()))));
|
||||
|
||||
// Tests for errors on overflow/underflow of binary operations with abstract numbers
|
||||
struct OverflowCase {
|
||||
ast::BinaryOp op;
|
||||
Types lhs;
|
||||
Types rhs;
|
||||
std::string overflowed_result;
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) {
|
||||
o << ast::FriendlyName(c.op) << ", lhs: " << c.lhs << ", rhs: " << c.rhs;
|
||||
return o;
|
||||
}
|
||||
using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCase>;
|
||||
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
|
||||
Enable(ast::Extension::kF16);
|
||||
auto& c = GetParam();
|
||||
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
|
||||
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
|
||||
auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
|
||||
GlobalConst("C", expr);
|
||||
ASSERT_FALSE(r()->Resolve());
|
||||
|
||||
std::string type_name = std::visit(
|
||||
[&](auto&& value) {
|
||||
using ValueType = std::decay_t<decltype(value)>;
|
||||
return tint::FriendlyName<typename ValueType::ElementType>();
|
||||
},
|
||||
c.lhs);
|
||||
|
||||
EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + c.overflowed_result +
|
||||
"' cannot be represented as '" + type_name + "'"));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
Test,
|
||||
ResolverConstEvalBinaryOpTest_Overflow,
|
||||
testing::Values( //
|
||||
// scalar-scalar add
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"},
|
||||
// scalar-scalar subtract
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a),
|
||||
"9223372036854775807"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a),
|
||||
"-9223372036854775808"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()),
|
||||
"inf"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()),
|
||||
"-inf"},
|
||||
|
||||
// scalar-scalar multiply
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"},
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"},
|
||||
|
||||
// scalar-vector multiply
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"},
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"},
|
||||
|
||||
// vector-matrix multiply
|
||||
|
||||
// Overflow from first multiplication of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(AFloat::Highest(), 1.0_a), //
|
||||
Mat({2.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from second multiplication of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(1.0_a, AFloat::Highest()), //
|
||||
Mat({1.0_a, 2.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from addition of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(AFloat::Highest(), AFloat::Highest()), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// matrix-matrix multiply
|
||||
|
||||
// Overflow from first multiplication of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({AFloat::Highest(), 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
Mat({2.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from second multiplication of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({1.0_a, AFloat::Highest()}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{2.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from addition of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({AFloat::Highest(), 1.0_a}, //
|
||||
{AFloat::Highest(), 1.0_a}), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"}
|
||||
|
||||
));
|
||||
|
||||
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
|
||||
GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
@@ -9572,108 +9572,108 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[727],
|
||||
/* return matcher indices */ &kMatcherIndices[1],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [118] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[725],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [119] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[723],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [120] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[721],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [121] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[719],
|
||||
/* return matcher indices */ &kMatcherIndices[10],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [122] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[717],
|
||||
/* return matcher indices */ &kMatcherIndices[10],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [123] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[1],
|
||||
/* parameters */ &kParameters[715],
|
||||
/* return matcher indices */ &kMatcherIndices[69],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyMatVec,
|
||||
},
|
||||
{
|
||||
/* [124] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[1],
|
||||
/* parameters */ &kParameters[713],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyVecMat,
|
||||
},
|
||||
{
|
||||
/* [125] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 3,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[0],
|
||||
/* parameters */ &kParameters[711],
|
||||
/* return matcher indices */ &kMatcherIndices[22],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyMatMat,
|
||||
},
|
||||
{
|
||||
/* [126] */
|
||||
@@ -14630,15 +14630,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
|
||||
},
|
||||
{
|
||||
/* [2] */
|
||||
/* op *<T : fiu32_f16>(T, T) -> T */
|
||||
/* op *<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op *<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* op *<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
|
||||
/* op *<T : f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
|
||||
/* op *<T : f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
|
||||
/* op *<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
|
||||
/* op *<T : fia_fiu32_f16>(T, T) -> T */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fa_f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* op *<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
|
||||
/* op *<T : fa_f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
|
||||
/* op *<T : fa_f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
|
||||
/* op *<T : fa_f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
|
||||
/* num overloads */ 9,
|
||||
/* overloads */ &kOverloads[117],
|
||||
},
|
||||
|
||||
@@ -641,15 +641,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
|
||||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator * (T, T) -> T where: T is f32, i32, u32 or f16
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator * (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
|
||||
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
|
||||
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
@@ -673,15 +673,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
|
||||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator *= (T, T) -> T where: T is f32, i32, u32 or f16
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
#include <ostream>
|
||||
#include <variant>
|
||||
#include "src/tint/debug.h"
|
||||
|
||||
namespace tint::utils {
|
||||
|
||||
@@ -36,6 +37,9 @@ struct [[nodiscard]] Result {
|
||||
static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
|
||||
"Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
|
||||
|
||||
/// Default constructor initializes to invalid state
|
||||
Result() : value(std::monostate{}) {}
|
||||
|
||||
/// Constructor
|
||||
/// @param success the success result
|
||||
Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
|
||||
@@ -47,27 +51,43 @@ struct [[nodiscard]] Result {
|
||||
: value{failure} {}
|
||||
|
||||
/// @returns true if the result was a success
|
||||
operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
|
||||
operator bool() const {
|
||||
Validate();
|
||||
return std::holds_alternative<SUCCESS_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns true if the result was a failure
|
||||
bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
|
||||
bool operator!() const {
|
||||
Validate();
|
||||
return std::holds_alternative<FAILURE_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns the success value
|
||||
/// @warning attempting to call this when the Result holds an failure will result in UB.
|
||||
const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
|
||||
const SUCCESS_TYPE* operator->() const {
|
||||
Validate();
|
||||
return &(Get());
|
||||
}
|
||||
|
||||
/// @returns the success value
|
||||
/// @warning attempting to call this when the Result holds an failure value will result in UB.
|
||||
const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
|
||||
const SUCCESS_TYPE& Get() const {
|
||||
Validate();
|
||||
return std::get<SUCCESS_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns the failure value
|
||||
/// @warning attempting to call this when the Result holds a success value will result in UB.
|
||||
const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
|
||||
const FAILURE_TYPE& Failure() const {
|
||||
Validate();
|
||||
return std::get<FAILURE_TYPE>(value);
|
||||
}
|
||||
|
||||
/// Equality operator
|
||||
/// @param val the value to compare this Result to
|
||||
/// @returns true if this result holds a success value equal to `value`
|
||||
bool operator==(SUCCESS_TYPE val) const {
|
||||
Validate();
|
||||
if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
|
||||
return *v == val;
|
||||
}
|
||||
@@ -78,14 +98,18 @@ struct [[nodiscard]] Result {
|
||||
/// @param val the value to compare this Result to
|
||||
/// @returns true if this result holds a failure value equal to `value`
|
||||
bool operator==(FAILURE_TYPE val) const {
|
||||
Validate();
|
||||
if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
|
||||
return *v == val;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative<std::monostate>(value)); }
|
||||
|
||||
/// The result. Either a success of failure value.
|
||||
std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
|
||||
std::variant<std::monostate, SUCCESS_TYPE, FAILURE_TYPE> value;
|
||||
};
|
||||
|
||||
/// Writes the result to the ostream.
|
||||
|
||||
@@ -151,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("a");
|
||||
auto* rhs = Expr(1_f);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
@@ -162,13 +163,14 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
|
||||
EXPECT_EQ(out.str(), "(a * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("a");
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
@@ -179,12 +181,13 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
|
||||
EXPECT_EQ(out.str(), "(a * 1.0hf)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr("a");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
@@ -194,14 +197,15 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
|
||||
EXPECT_EQ(out.str(), "(1.0f * a)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr("a");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
@@ -211,7 +215,7 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
|
||||
EXPECT_EQ(out.str(), "(1.0hf * a)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
|
||||
@@ -184,7 +184,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
|
||||
EXPECT_EQ(out.str(), "(1.0f).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
@@ -201,7 +201,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
@@ -216,7 +216,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
|
||||
EXPECT_EQ(out.str(), "(1.0f).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
@@ -233,7 +233,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
|
||||
Reference in New Issue
Block a user