tint: Implement const eval of binary multiply
Bug: tint:1581
Change-Id: I70ff40ed4d8faf0a665824fef936ffbafb3f0948
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99362
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
ae6f76fe3a
commit
c20c5dfb4a
|
@ -898,15 +898,15 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
|
|||
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op * <T: fiu32_f16>(T, T) -> T
|
||||
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op * <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op * <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op * <T: f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
|
||||
op * <T: f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
|
||||
op * <T: f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
|
||||
op * <T: f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
|
||||
op * <T: f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16>(T, T) -> T
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
|
||||
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
|
||||
@const("MultiplyMatVec") op * <T: fa_f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
|
||||
@const("MultiplyVecMat") op * <T: fa_f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
|
||||
@const("MultiplyMatMat") op * <T: fa_f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
|
||||
|
||||
op / <T: fiu32_f16>(T, T) -> T
|
||||
op / <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <optional>
|
||||
#include <ostream>
|
||||
|
||||
#include "src/tint/traits.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/result.h"
|
||||
|
||||
|
@ -33,6 +34,14 @@ struct Number;
|
|||
} // namespace tint
|
||||
|
||||
namespace tint::detail {
|
||||
/// Base template for IsNumber
|
||||
template <typename T>
|
||||
struct IsNumber : std::false_type {};
|
||||
|
||||
/// Specialization for IsNumber
|
||||
template <typename T>
|
||||
struct IsNumber<Number<T>> : std::true_type {};
|
||||
|
||||
/// An empty structure used as a unique template type for Number when
|
||||
/// specializing for the f16 type.
|
||||
struct NumberKindF16 {};
|
||||
|
@ -68,6 +77,10 @@ constexpr bool IsInteger = std::is_integral_v<T>;
|
|||
template <typename T>
|
||||
constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
|
||||
|
||||
/// Evaluates to true iff T is a Number
|
||||
template <typename T>
|
||||
constexpr bool IsNumber = detail::IsNumber<T>::value;
|
||||
|
||||
/// Resolves to the underlying type for a Number.
/// This is an alias for the nested `type` member of detail::NumberUnwrapper<T>.
|
||||
template <typename T>
|
||||
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
|
||||
|
@ -236,6 +249,26 @@ using f32 = Number<float>;
|
|||
/// However, since C++ doesn't have a native binary16 type, the value is stored as float.
|
||||
using f16 = Number<detail::NumberKindF16>;
|
||||
|
||||
/// @returns the friendly name of Number type T
|
||||
template <typename T, typename = traits::EnableIf<IsNumber<T>>>
|
||||
const char* FriendlyName() {
|
||||
if constexpr (std::is_same_v<T, AInt>) {
|
||||
return "abstract-int";
|
||||
} else if constexpr (std::is_same_v<T, AFloat>) {
|
||||
return "abstract-float";
|
||||
} else if constexpr (std::is_same_v<T, i32>) {
|
||||
return "i32";
|
||||
} else if constexpr (std::is_same_v<T, u32>) {
|
||||
return "u32";
|
||||
} else if constexpr (std::is_same_v<T, f32>) {
|
||||
return "f32";
|
||||
} else if constexpr (std::is_same_v<T, f16>) {
|
||||
return "f16";
|
||||
} else {
|
||||
static_assert(!sizeof(T), "Unhandled type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerator of failure reasons when converting from one number to another.
|
||||
enum class ConversionFailure {
|
||||
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
|
||||
|
@ -437,6 +470,15 @@ inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
|
|||
return AInt(result);
|
||||
}
|
||||
|
||||
/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
|
||||
inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
|
||||
auto result = a.value * b.value;
|
||||
if (!std::isfinite(result)) {
|
||||
return {};
|
||||
}
|
||||
return AFloat{result};
|
||||
}
|
||||
|
||||
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
|
||||
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
|
||||
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "src/tint/sem/vector.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
@ -508,11 +509,202 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder,
|
|||
auto* ty = n0 > n1 ? c0->Type() : c1->Type();
|
||||
return CreateComposite(builder, ty, std::move(els));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto add_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// Ensure no UB for signed overflow
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedAdd(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(add_values(a.value, b.value)) +
|
||||
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
|
||||
*current_source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
result = add_values(a.value, b.value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
auto mul_values = [](T lhs, T rhs) { //
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// For signed integrals, avoid C++ UB by multiplying as unsigned
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) * static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedMul(a, b)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(mul_values(a.value, b.value)) +
|
||||
"' cannot be represented as '" + FriendlyName<NumberT>() + "'",
|
||||
*current_source);
|
||||
return utils::Failure;
|
||||
}
|
||||
} else {
|
||||
result = mul_values(a.value, b.value);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot3(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r3 = Mul(a3, b3);
|
||||
if (!r3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT a4,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3,
|
||||
NumberT b4) {
|
||||
auto r1 = Mul(a1, b1);
|
||||
if (!r1) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r2 = Mul(a2, b2);
|
||||
if (!r2) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r3 = Mul(a3, b3);
|
||||
if (!r3) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r4 = Mul(a4, b4);
|
||||
if (!r4) {
|
||||
return utils::Failure;
|
||||
}
|
||||
auto r = Add(r1.Get(), r2.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r3.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
r = Add(r.Get(), r4.Get());
|
||||
if (!r) {
|
||||
return utils::Failure;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// Returns a callable that adds its two arguments with Add() and, on success, wraps the sum in a
// Constant of type `elem_ty` via CreateElement(). The callable returns Failure if Add() fails.
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
    return [this, elem_ty](auto lhs, auto rhs) -> utils::Result<const Constant*> {
        auto sum = Add(lhs, rhs);
        if (!sum) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, sum.Get());
    };
}
|
||||
|
||||
// Returns a callable that multiplies its two arguments with Mul() and, on success, wraps the
// product in a Constant of type `elem_ty` via CreateElement(). The callable returns Failure if
// Mul() fails.
auto ConstEval::MulFunc(const sem::Type* elem_ty) {
    return [this, elem_ty](auto lhs, auto rhs) -> utils::Result<const Constant*> {
        auto product = Mul(lhs, rhs);
        if (!product) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, product.Get());
    };
}
|
||||
|
||||
// Returns a callable that forwards its four arguments to Dot2() and, on success, wraps the result
// in a Constant of type `elem_ty` via CreateElement(). The callable returns Failure if Dot2()
// fails.
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
        auto dot = Dot2(a1, a2, b1, b2);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
|
||||
|
||||
// Returns a callable that forwards its six arguments to Dot3() and, on success, wraps the result
// in a Constant of type `elem_ty` via CreateElement(). The callable returns Failure if Dot3()
// fails.
auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto a3,  //
                           auto b1, auto b2, auto b3) -> utils::Result<const Constant*> {
        auto dot = Dot3(a1, a2, a3, b1, b2, b3);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
|
||||
|
||||
// Returns a callable that forwards its eight arguments to Dot4() and, on success, wraps the result
// in a Constant of type `elem_ty` via CreateElement(). The callable returns Failure if Dot4()
// fails.
auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto a3, auto a4,  //
                           auto b1, auto b2, auto b3, auto b4) -> utils::Result<const Constant*> {
        auto dot = Dot4(a1, a2, a3, a4, b1, b2, b3, b4);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
|
||||
const ast::LiteralExpression* literal) {
|
||||
return Switch(
|
||||
|
@ -756,42 +948,15 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
|
|||
return TransformElements(builder, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
|
||||
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto i, auto j) -> const Constant* {
|
||||
using NumberT = decltype(i);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
|
||||
auto add_values = [](T lhs, T rhs) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// Ensure no UB for signed overflow
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
|
||||
} else {
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
|
||||
NumberT result;
|
||||
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
|
||||
// Check for over/underflow for abstract values
|
||||
if (auto r = CheckedAdd(i, j)) {
|
||||
result = r->value;
|
||||
} else {
|
||||
AddError("'" + std::to_string(add_values(i.value, j.value)) +
|
||||
"' cannot be represented as '" +
|
||||
ty->FriendlyName(builder.Symbols()) + "'",
|
||||
source);
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
result = add_values(i.value, j.value);
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_fia_fiu32_f16(create, c0, c1);
|
||||
TINT_SCOPED_ASSIGNMENT(current_source, &source);
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
|
||||
if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
|
||||
return r.Get();
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
|
||||
|
@ -846,6 +1011,192 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
|
|||
return r;
|
||||
}
|
||||
|
||||
// Const-evaluates the '*' operator by applying Mul() to the element pairs that
// TransformBinaryElements() produces from args[0] and args[1].
// `source` is published through `current_source` for the duration of the call so that Mul() can
// attach errors to it. Returns Failure if the builder's diagnostics contain errors afterwards.
ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
                                                utils::VectorRef<const sem::Constant*> args,
                                                const Source& source) {
    TINT_SCOPED_ASSIGNMENT(current_source, &source);

    auto multiply_pair = [&](const sem::Constant* lhs,
                             const sem::Constant* rhs) -> const Constant* {
        auto product = Dispatch_fia_fiu32_f16(MulFunc(lhs->Type()), lhs, rhs);
        if (!product) {
            return nullptr;
        }
        return product.Get();
    };

    auto r = TransformBinaryElements(builder, multiply_pair, args[0], args[1]);
    if (builder.Diagnostics().contains_errors()) {
        return utils::Failure;
    }
    return r;
}
|
||||
|
||||
// Const-evaluates matrix * vector.
// args[0] is the matrix and args[1] is the vector. Element `i` of the result is the dot product of
// matrix row `i` with the vector. The dot product helper is selected by the matrix column count
// (2, 3 or 4). Returns Failure as soon as one dot product fails.
ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
    auto* vec_ty = args[1]->Type()->As<sem::Vector>();
    auto* elem_ty = vec_ty->type();

    // Dot product of row `row` of matrix `mat` with vector `vec`.
    auto row_dot = [&](const sem::Constant* mat, size_t row, const sem::Constant* vec) {
        // Matrix constants are indexed by column first, then by row.
        auto m = [&](size_t col) { return mat->Index(col)->Index(row); };
        auto v = [&](size_t idx) { return vec->Index(idx); };

        utils::Result<const Constant*> result;
        switch (mat_ty->columns()) {
            case 2:
                result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),  //
                                             m(0), m(1),         //
                                             v(0), v(1));
                break;
            case 3:
                result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),  //
                                             m(0), m(1), m(2),   //
                                             v(0), v(1), v(2));
                break;
            case 4:
                result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),       //
                                             m(0), m(1), m(2), m(3),  //
                                             v(0), v(1), v(2), v(3));
                break;
        }
        return result;
    };

    utils::Vector<const sem::Constant*, 4> result;
    for (size_t row = 0; row < mat_ty->rows(); ++row) {
        auto element = row_dot(args[0], row, args[1]);  // matrix row * vector
        if (!element) {
            return utils::Failure;
        }
        result.Push(element.Get());
    }
    return CreateComposite(builder, ty, result);
}
|
||||
// Const-evaluates vector * matrix.
// args[0] is the vector and args[1] is the matrix. Element `i` of the result is the dot product of
// the vector with matrix column `i`. The dot product helper is selected by the matrix row count
// (2, 3 or 4). Returns Failure as soon as one dot product fails.
ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* vec_ty = args[0]->Type()->As<sem::Vector>();
    auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
    auto* elem_ty = vec_ty->type();

    // Dot product of vector `vec` with column `col` of matrix `mat`.
    auto col_dot = [&](const sem::Constant* vec, const sem::Constant* mat, size_t col) {
        // Matrix constants are indexed by column first, then by row.
        auto m = [&](size_t row) { return mat->Index(col)->Index(row); };
        auto v = [&](size_t idx) { return vec->Index(idx); };

        utils::Result<const Constant*> result;
        switch (mat_ty->rows()) {
            case 2:
                result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),  //
                                             m(0), m(1),         //
                                             v(0), v(1));
                break;
            case 3:
                result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),  //
                                             m(0), m(1), m(2),   //
                                             v(0), v(1), v(2));
                break;
            case 4:
                result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),       //
                                             m(0), m(1), m(2), m(3),  //
                                             v(0), v(1), v(2), v(3));
                break;
        }
        return result;
    };

    utils::Vector<const sem::Constant*, 4> result;
    for (size_t col = 0; col < mat_ty->columns(); ++col) {
        auto element = col_dot(args[0], args[1], col);  // vector * matrix column
        if (!element) {
            return utils::Failure;
        }
        result.Push(element.Get());
    }
    return CreateComposite(builder, ty, result);
}
|
||||
|
||||
// Const-evaluates matrix * matrix.
// The result is built column by column: the element at (row r, column c) is the dot product of
// row r of args[0] with column c of args[1]. The dot product helper is selected by the column
// count of args[0] (2, 3 or 4). Returns Failure as soon as one dot product fails.
ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* lhs = args[0];
    auto* rhs = args[1];
    auto* lhs_ty = lhs->Type()->As<sem::Matrix>();
    auto* rhs_ty = rhs->Type()->As<sem::Matrix>();
    auto* elem_ty = lhs_ty->type();

    // Dot product of row `row` of `m1` with column `col` of `m2`.
    auto row_col_dot = [&](const sem::Constant* m1, size_t row,  //
                           const sem::Constant* m2, size_t col) {
        // Matrix constants are indexed by column first, then by row.
        auto a = [&](size_t k) { return m1->Index(k)->Index(row); };  // m1 at (row, k)
        auto b = [&](size_t k) { return m2->Index(col)->Index(k); };  // m2 at (k, col)

        utils::Result<const Constant*> result;
        switch (lhs_ty->columns()) {
            case 2:
                result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),  //
                                             a(0), a(1),         //
                                             b(0), b(1));
                break;
            case 3:
                result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),  //
                                             a(0), a(1), a(2),   //
                                             b(0), b(1), b(2));
                break;
            case 4:
                result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),       //
                                             a(0), a(1), a(2), a(3),  //
                                             b(0), b(1), b(2), b(3));
                break;
        }
        return result;
    };

    utils::Vector<const sem::Constant*, 4> columns;
    for (size_t col = 0; col < rhs_ty->columns(); ++col) {
        utils::Vector<const sem::Constant*, 4> column;
        for (size_t row = 0; row < lhs_ty->rows(); ++row) {
            auto element = row_col_dot(lhs, row, rhs, col);
            if (!element) {
                return utils::Failure;
            }
            column.Push(element.Get());
        }

        // Add column vector to matrix
        auto* column_ty = ty->As<sem::Matrix>()->ColumnType();
        columns.Push(CreateComposite(builder, column_ty, column));
    }
    return CreateComposite(builder, ty, columns);
}
|
||||
|
||||
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
|
|
@ -230,6 +230,42 @@ class ConstEval {
|
|||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for the same type on the LHS and RHS
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiply(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for matCxR<T> * vecC<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyMatVec(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for vecR<T> * matCxR<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyVecMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Multiply operator '*' for matKxR<T> * matCxK<T>
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
ConstantResult OpMultiplyMatMat(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Builtins
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -259,7 +295,97 @@ class ConstEval {
|
|||
/// Adds the given warning message to the diagnostics
|
||||
void AddWarning(const std::string& msg, const Source& source) const;
|
||||
|
||||
/// Adds two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Add(NumberT a, NumberT b);
|
||||
|
||||
/// Multiplies two Number<T>s
|
||||
/// @param a the lhs number
|
||||
/// @param b the rhs number
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Mul(NumberT a, NumberT b);
|
||||
|
||||
/// Returns the dot product of (a1,a2) with (b1,b2)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2);
|
||||
|
||||
/// Returns the dot product of (a1,a2,a3) with (b1,b2,b3)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param a3 component 3 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @param b3 component 3 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot3(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3);
|
||||
|
||||
/// Returns the dot product of (a1,b1,c1,d1) with (a2,b2,c2,d2)
|
||||
/// @param a1 component 1 of lhs vector
|
||||
/// @param a2 component 2 of lhs vector
|
||||
/// @param a3 component 3 of lhs vector
|
||||
/// @param a4 component 4 of lhs vector
|
||||
/// @param b1 component 1 of rhs vector
|
||||
/// @param b2 component 2 of rhs vector
|
||||
/// @param b3 component 3 of rhs vector
|
||||
/// @param b4 component 4 of rhs vector
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Dot4(NumberT a1,
|
||||
NumberT a2,
|
||||
NumberT a3,
|
||||
NumberT a4,
|
||||
NumberT b1,
|
||||
NumberT b2,
|
||||
NumberT b3,
|
||||
NumberT b4);
|
||||
|
||||
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto AddFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto MulFunc(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot2Func(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot3Func(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto Dot4Func(const sem::Type* elem_ty);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
const Source* current_source = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include <cmath>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "src/tint/resolver/resolver_test_helper.h"
|
||||
#include "src/tint/sem/builtin_type.h"
|
||||
|
@ -24,6 +25,8 @@
|
|||
#include "src/tint/sem/test_helper.h"
|
||||
#include "src/tint/utils/transform.h"
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::resolver {
|
||||
|
@ -74,6 +77,19 @@ auto Abs(const Number<T>& v) {
|
|||
}
|
||||
}
|
||||
|
||||
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||
template <typename T>
|
||||
constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
|
||||
// For signed integrals, avoid C++ UB by multiplying as unsigned
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
|
||||
} else {
|
||||
return static_cast<Number<T>>(v1 * v2);
|
||||
}
|
||||
}
|
||||
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||
|
||||
// Concats any number of std::vectors
|
||||
template <typename Vec, typename... Vecs>
|
||||
[[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) {
|
||||
|
@ -3228,29 +3244,26 @@ Case C(T lhs, U rhs, V expected, bool overflow = false) {
|
|||
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
|
||||
}
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
auto print_value = [&](auto&& value) {
|
||||
std::visit(
|
||||
[&](auto&& v) {
|
||||
using ValueType = std::decay_t<decltype(v)>;
|
||||
o << ValueType::DataType::Name() << "(";
|
||||
for (auto& a : v.args.values) {
|
||||
o << std::get<typename ValueType::ElementType>(a);
|
||||
if (&a != &v.args.values.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
static std::ostream& operator<<(std::ostream& o, const Types& types) {
|
||||
std::visit(
|
||||
[&](auto&& v) {
|
||||
using ValueType = std::decay_t<decltype(v)>;
|
||||
o << ValueType::DataType::Name() << "(";
|
||||
for (auto& a : v.args.values) {
|
||||
o << std::get<typename ValueType::ElementType>(a);
|
||||
if (&a != &v.args.values.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
o << ")";
|
||||
},
|
||||
value);
|
||||
};
|
||||
o << "lhs: ";
|
||||
print_value(c.lhs);
|
||||
o << ", rhs: ";
|
||||
print_value(c.rhs);
|
||||
o << ", expected: ";
|
||||
print_value(c.expected);
|
||||
o << ", overflow: " << c.overflow;
|
||||
}
|
||||
o << ")";
|
||||
},
|
||||
types);
|
||||
return o;
|
||||
}
|
||||
|
||||
// Prints the lhs, rhs, expected value and overflow flag of a Case.
static std::ostream& operator<<(std::ostream& o, const Case& c) {
    o << "lhs: " << c.lhs;
    o << ", rhs: " << c.rhs;
    o << ", expected: " << c.expected;
    o << ", overflow: " << c.overflow;
    return o;
}
|
||||
|
||||
|
@ -3281,7 +3294,6 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
|
|||
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
|
||||
TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto op = std::get<0>(GetParam());
|
||||
auto& c = std::get<1>(GetParam());
|
||||
|
||||
|
@ -3300,10 +3312,8 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
|||
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
|
||||
|
||||
GlobalConst("C", expr);
|
||||
|
||||
auto* expected_expr = expected.Expr(*this);
|
||||
GlobalConst("E", expected_expr);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* sem = Sem().Get(expr);
|
||||
|
@ -3413,6 +3423,215 @@ INSTANTIATE_TEST_SUITE_P(Sub,
|
|||
OpSubFloatCases<f32>(),
|
||||
OpSubFloatCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulScalarCases() {
|
||||
return {
|
||||
C(T{0}, T{0}, T{0}),
|
||||
C(T{1}, T{2}, T{2}),
|
||||
C(T{2}, T{3}, T{6}),
|
||||
C(Negate(T{2}), T{3}, Negate(T{6})),
|
||||
C(T::Highest(), T{1}, T::Highest()),
|
||||
C(T::Lowest(), T{1}, T::Lowest()),
|
||||
C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true),
|
||||
C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true),
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulVecCases() {
|
||||
return {
|
||||
// s * vec3 = vec3
|
||||
C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
// vec3 * s = vec3
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
// vec3 * vec3 = vec3
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
|
||||
};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> OpMulMatCases() {
|
||||
return {
|
||||
// s * mat3x2 = mat3x2
|
||||
C(Val(T{2.25}),
|
||||
Mat({T{1.0}, T{4.0}}, //
|
||||
{T{2.0}, T{5.0}}, //
|
||||
{T{3.0}, T{6.0}}),
|
||||
Mat({T{2.25}, T{9.0}}, //
|
||||
{T{4.5}, T{11.25}}, //
|
||||
{T{6.75}, T{13.5}})),
|
||||
// mat3x2 * s = mat3x2
|
||||
C(Mat({T{1.0}, T{4.0}}, //
|
||||
{T{2.0}, T{5.0}}, //
|
||||
{T{3.0}, T{6.0}}),
|
||||
Val(T{2.25}),
|
||||
Mat({T{2.25}, T{9.0}}, //
|
||||
{T{4.5}, T{11.25}}, //
|
||||
{T{6.75}, T{13.5}})),
|
||||
// vec3 * mat2x3 = vec2
|
||||
C(Vec(T{1.25}, T{2.25}, T{3.25}), //
|
||||
Mat({T{1.0}, T{2.0}, T{3.0}}, //
|
||||
{T{4.0}, T{5.0}, T{6.0}}), //
|
||||
Vec(T{15.5}, T{35.75})),
|
||||
// mat2x3 * vec2 = vec3
|
||||
C(Mat({T{1.0}, T{2.0}, T{3.0}}, //
|
||||
{T{4.0}, T{5.0}, T{6.0}}), //
|
||||
Vec(T{1.25}, T{2.25}), //
|
||||
Vec(T{10.25}, T{13.75}, T{17.25})),
|
||||
// mat3x2 * mat2x3 = mat2x2
|
||||
C(Mat({T{1.0}, T{2.0}}, //
|
||||
{T{3.0}, T{4.0}}, //
|
||||
{T{5.0}, T{6.0}}), //
|
||||
Mat({T{1.25}, T{2.25}, T{3.25}}, //
|
||||
{T{4.25}, T{5.25}, T{6.25}}), //
|
||||
Mat({T{24.25}, T{31.0}}, //
|
||||
{T{51.25}, T{67.0}})), //
|
||||
};
|
||||
}
|
||||
|
||||
// Runs ResolverConstEvalBinaryOpTest with ast::BinaryOp::kMultiply over every
// multiply case list. Scalar and vector lists are instantiated for the integer
// and float element types; the matrix lists only for the float element types.
INSTANTIATE_TEST_SUITE_P(Mul,
                         ResolverConstEvalBinaryOpTest,
                         testing::Combine(
                             testing::Values(ast::BinaryOp::kMultiply),
                             testing::ValuesIn(Concat(
                                 // Scalar operands
                                 OpMulScalarCases<AInt>(),
                                 OpMulScalarCases<i32>(),
                                 OpMulScalarCases<u32>(),
                                 OpMulScalarCases<AFloat>(),
                                 OpMulScalarCases<f32>(),
                                 OpMulScalarCases<f16>(),
                                 // Vector operands
                                 OpMulVecCases<AInt>(),
                                 OpMulVecCases<i32>(),
                                 OpMulVecCases<u32>(),
                                 OpMulVecCases<AFloat>(),
                                 OpMulVecCases<f32>(),
                                 OpMulVecCases<f16>(),
                                 // Matrix operands
                                 OpMulMatCases<AFloat>(),
                                 OpMulMatCases<f32>(),
                                 OpMulMatCases<f16>()))));
|
||||
|
||||
// Tests for errors on overflow/underflow of binary operations with abstract numbers
|
||||
struct OverflowCase {
|
||||
ast::BinaryOp op;
|
||||
Types lhs;
|
||||
Types rhs;
|
||||
std::string overflowed_result;
|
||||
};
|
||||
|
||||
/// Streams the operator name and both operands of an OverflowCase.
/// The expected `overflowed_result` text is not written.
/// @param o the stream to write to
/// @param c the case to describe
/// @returns `o`
static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) {
    o << ast::FriendlyName(c.op);
    o << ", lhs: " << c.lhs;
    o << ", rhs: " << c.rhs;
    return o;
}
|
||||
using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCase>;
|
||||
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
|
||||
Enable(ast::Extension::kF16);
|
||||
auto& c = GetParam();
|
||||
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
|
||||
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
|
||||
auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
|
||||
GlobalConst("C", expr);
|
||||
ASSERT_FALSE(r()->Resolve());
|
||||
|
||||
std::string type_name = std::visit(
|
||||
[&](auto&& value) {
|
||||
using ValueType = std::decay_t<decltype(value)>;
|
||||
return tint::FriendlyName<typename ValueType::ElementType>();
|
||||
},
|
||||
c.lhs);
|
||||
|
||||
EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + c.overflowed_result +
|
||||
"' cannot be represented as '" + type_name + "'"));
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
Test,
|
||||
ResolverConstEvalBinaryOpTest_Overflow,
|
||||
testing::Values( //
|
||||
// scalar-scalar add
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"},
|
||||
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"},
|
||||
// scalar-scalar subtract
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a),
|
||||
"9223372036854775807"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a),
|
||||
"-9223372036854775808"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()),
|
||||
"inf"},
|
||||
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()),
|
||||
"-inf"},
|
||||
|
||||
// scalar-scalar multiply
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"},
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"},
|
||||
|
||||
// scalar-vector multiply
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"},
|
||||
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"},
|
||||
|
||||
// vector-matrix multiply
|
||||
|
||||
// Overflow from first multiplication of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(AFloat::Highest(), 1.0_a), //
|
||||
Mat({2.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from second multiplication of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(1.0_a, AFloat::Highest()), //
|
||||
Mat({1.0_a, 2.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from addition of dot product of vector and matrix column 0
|
||||
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Vec(AFloat::Highest(), AFloat::Highest()), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// matrix-matrix multiply
|
||||
|
||||
// Overflow from first multiplication of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({AFloat::Highest(), 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
Mat({2.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from second multiplication of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({1.0_a, AFloat::Highest()}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{2.0_a, 1.0_a}), //
|
||||
"inf"},
|
||||
|
||||
// Overflow from addition of dot product of lhs row 0 and rhs column 0
|
||||
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0]
|
||||
// ^
|
||||
OverflowCase{ast::BinaryOp::kMultiply, //
|
||||
Mat({AFloat::Highest(), 1.0_a}, //
|
||||
{AFloat::Highest(), 1.0_a}), //
|
||||
Mat({1.0_a, 1.0_a}, //
|
||||
{1.0_a, 1.0_a}), //
|
||||
"inf"}
|
||||
|
||||
));
|
||||
|
||||
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
|
||||
GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
|
|
@ -9572,108 +9572,108 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[727],
|
||||
/* return matcher indices */ &kMatcherIndices[1],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [118] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[725],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [119] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[723],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [120] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[15],
|
||||
/* template types */ &kTemplateTypes[13],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[721],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [121] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[719],
|
||||
/* return matcher indices */ &kMatcherIndices[10],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [122] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[6],
|
||||
/* parameters */ &kParameters[717],
|
||||
/* return matcher indices */ &kMatcherIndices[10],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiply,
|
||||
},
|
||||
{
|
||||
/* [123] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[1],
|
||||
/* parameters */ &kParameters[715],
|
||||
/* return matcher indices */ &kMatcherIndices[69],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyMatVec,
|
||||
},
|
||||
{
|
||||
/* [124] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 2,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[1],
|
||||
/* parameters */ &kParameters[713],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyVecMat,
|
||||
},
|
||||
{
|
||||
/* [125] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 3,
|
||||
/* template types */ &kTemplateTypes[11],
|
||||
/* template types */ &kTemplateTypes[12],
|
||||
/* template numbers */ &kTemplateNumbers[0],
|
||||
/* parameters */ &kParameters[711],
|
||||
/* return matcher indices */ &kMatcherIndices[22],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpMultiplyMatMat,
|
||||
},
|
||||
{
|
||||
/* [126] */
|
||||
|
@ -14630,15 +14630,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
|
|||
},
|
||||
{
|
||||
/* [2] */
|
||||
/* op *<T : fiu32_f16>(T, T) -> T */
|
||||
/* op *<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op *<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* op *<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
|
||||
/* op *<T : f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
|
||||
/* op *<T : f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
|
||||
/* op *<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
|
||||
/* op *<T : fia_fiu32_f16>(T, T) -> T */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
|
||||
/* op *<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
|
||||
/* op *<T : fa_f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
|
||||
/* op *<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
|
||||
/* op *<T : fa_f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
|
||||
/* op *<T : fa_f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
|
||||
/* op *<T : fa_f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
|
||||
/* num overloads */ 9,
|
||||
/* overloads */ &kOverloads[117],
|
||||
},
|
||||
|
|
|
@ -641,15 +641,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
|
|||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator * (T, T) -> T where: T is f32, i32, u32 or f16
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator * (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
|
||||
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
|
||||
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -673,15 +673,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
|
|||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator *= (T, T) -> T where: T is f32, i32, u32 or f16
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
|
||||
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
|
||||
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
|
||||
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <ostream>
|
||||
#include <variant>
|
||||
#include "src/tint/debug.h"
|
||||
|
||||
namespace tint::utils {
|
||||
|
||||
|
@ -36,6 +37,9 @@ struct [[nodiscard]] Result {
|
|||
static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
|
||||
"Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
|
||||
|
||||
/// Default constructor initializes to invalid state
|
||||
Result() : value(std::monostate{}) {}
|
||||
|
||||
/// Constructor
|
||||
/// @param success the success result
|
||||
Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
|
||||
|
@ -47,27 +51,43 @@ struct [[nodiscard]] Result {
|
|||
: value{failure} {}
|
||||
|
||||
/// @returns true if the result was a success
|
||||
operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
|
||||
operator bool() const {
|
||||
Validate();
|
||||
return std::holds_alternative<SUCCESS_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns true if the result was a failure
|
||||
bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
|
||||
bool operator!() const {
|
||||
Validate();
|
||||
return std::holds_alternative<FAILURE_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns the success value
|
||||
/// @warning attempting to call this when the Result holds a failure value will result in UB.
|
||||
const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
|
||||
const SUCCESS_TYPE* operator->() const {
|
||||
Validate();
|
||||
return &(Get());
|
||||
}
|
||||
|
||||
/// @returns the success value
|
||||
/// @warning attempting to call this when the Result holds a failure value will result in UB.
|
||||
const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
|
||||
const SUCCESS_TYPE& Get() const {
|
||||
Validate();
|
||||
return std::get<SUCCESS_TYPE>(value);
|
||||
}
|
||||
|
||||
/// @returns the failure value
|
||||
/// @warning attempting to call this when the Result holds a success value will result in UB.
|
||||
const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
|
||||
const FAILURE_TYPE& Failure() const {
|
||||
Validate();
|
||||
return std::get<FAILURE_TYPE>(value);
|
||||
}
|
||||
|
||||
/// Equality operator
|
||||
/// @param val the value to compare this Result to
|
||||
/// @returns true if this result holds a success value equal to `value`
|
||||
bool operator==(SUCCESS_TYPE val) const {
|
||||
Validate();
|
||||
if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
|
||||
return *v == val;
|
||||
}
|
||||
|
@ -78,14 +98,18 @@ struct [[nodiscard]] Result {
|
|||
/// @param val the value to compare this Result to
|
||||
/// @returns true if this result holds a failure value equal to `value`
|
||||
bool operator==(FAILURE_TYPE val) const {
|
||||
Validate();
|
||||
if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
|
||||
return *v == val;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative<std::monostate>(value)); }
|
||||
|
||||
/// The result. Either a success or failure value.
|
||||
std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
|
||||
std::variant<std::monostate, SUCCESS_TYPE, FAILURE_TYPE> value;
|
||||
};
|
||||
|
||||
/// Writes the result to the ostream.
|
||||
|
|
|
@ -151,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("a");
|
||||
auto* rhs = Expr(1_f);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
@ -162,13 +163,14 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
|
||||
EXPECT_EQ(out.str(), "(a * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("a");
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
@ -179,12 +181,13 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
|
||||
EXPECT_EQ(out.str(), "(a * 1.0hf)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr("a");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
|
@ -194,14 +197,15 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
|
||||
EXPECT_EQ(out.str(), "(1.0f * a)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr("a");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
|
@ -211,7 +215,7 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
|
||||
EXPECT_EQ(out.str(), "(1.0hf * a)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
|
|
|
@ -184,7 +184,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
|
||||
EXPECT_EQ(out.str(), "(1.0f).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
|
@ -201,7 +201,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
|
@ -216,7 +216,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
|
||||
EXPECT_EQ(out.str(), "(1.0f).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
|
@ -233,7 +233,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
|||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
|
|
|
@ -10,7 +10,7 @@ struct tint_symbol_1 {
|
|||
};
|
||||
|
||||
void main_inner(uint3 GlobalInvocationID) {
|
||||
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
flatIndex = (flatIndex * 1u);
|
||||
float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0));
|
||||
{
|
||||
|
|
|
@ -10,7 +10,7 @@ struct tint_symbol_1 {
|
|||
};
|
||||
|
||||
void main_inner(uint3 GlobalInvocationID) {
|
||||
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
flatIndex = (flatIndex * 1u);
|
||||
float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0));
|
||||
{
|
||||
|
|
|
@ -9,7 +9,7 @@ layout(binding = 3, std430) buffer Result_1 {
|
|||
} result;
|
||||
uniform highp sampler2DArray myTexture_1;
|
||||
void tint_symbol(uvec3 GlobalInvocationID) {
|
||||
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
|
||||
flatIndex = (flatIndex * 1u);
|
||||
vec4 texel = texelFetch(myTexture_1, ivec3(ivec2(GlobalInvocationID.xy), 0), 0);
|
||||
{
|
||||
|
|
|
@ -23,7 +23,7 @@ struct Result {
|
|||
};
|
||||
|
||||
void tint_symbol_inner(uint3 GlobalInvocationID, texture2d_array<float, access::sample> tint_symbol_1, device Result* const tint_symbol_2) {
|
||||
uint flatIndex = ((((2u * 2u) * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]);
|
||||
uint flatIndex = (((4u * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]);
|
||||
flatIndex = (flatIndex * 1u);
|
||||
float4 texel = tint_symbol_1.read(uint2(int2(uint3(GlobalInvocationID).xy)), 0, 0);
|
||||
for(uint i = 0u; (i < 1u); i = (i + 1u)) {
|
||||
|
|
|
@ -52,6 +52,7 @@
|
|||
%result = OpVariable %_ptr_StorageBuffer_Result StorageBuffer
|
||||
%void = OpTypeVoid
|
||||
%17 = OpTypeFunction %void %v3uint
|
||||
%uint_4 = OpConstant %uint 4
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%33 = OpConstantNull %uint
|
||||
|
@ -74,12 +75,11 @@
|
|||
%flatIndex = OpVariable %_ptr_Function_uint Function %33
|
||||
%texel = OpVariable %_ptr_Function_v4float Function %51
|
||||
%i = OpVariable %_ptr_Function_uint Function %33
|
||||
%23 = OpIMul %uint %uint_2 %uint_2
|
||||
%24 = OpCompositeExtract %uint %GlobalInvocationID 2
|
||||
%25 = OpIMul %uint %23 %24
|
||||
%23 = OpCompositeExtract %uint %GlobalInvocationID 2
|
||||
%24 = OpIMul %uint %uint_4 %23
|
||||
%26 = OpCompositeExtract %uint %GlobalInvocationID 1
|
||||
%27 = OpIMul %uint %uint_2 %26
|
||||
%28 = OpIAdd %uint %25 %27
|
||||
%28 = OpIAdd %uint %24 %27
|
||||
%29 = OpCompositeExtract %uint %GlobalInvocationID 0
|
||||
%30 = OpIAdd %uint %28 %29
|
||||
OpStore %flatIndex %30
|
||||
|
|
|
@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) {
|
|||
float ACached = 0.0f;
|
||||
float BCached[4] = (float[4])0;
|
||||
{
|
||||
[loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
|
||||
[loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) {
|
||||
acc[index] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) {
|
|||
float ACached = 0.0f;
|
||||
float BCached[4] = (float[4])0;
|
||||
{
|
||||
[loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
|
||||
[loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) {
|
||||
acc[index] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,7 +77,7 @@ void tint_symbol(uvec3 local_id, uvec3 global_id, uint local_invocation_index) {
|
|||
float ACached = 0.0f;
|
||||
float BCached[4] = float[4](0.0f, 0.0f, 0.0f, 0.0f);
|
||||
{
|
||||
for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
|
||||
for(uint index = 0u; (index < 16u); index = (index + 1u)) {
|
||||
acc[index] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ void tint_symbol_inner(uint3 local_id, uint3 global_id, uint local_invocation_in
|
|||
tint_array<float, 16> acc = {};
|
||||
float ACached = 0.0f;
|
||||
tint_array<float, 4> BCached = {};
|
||||
for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
|
||||
for(uint index = 0u; (index < 16u); index = (index + 1u)) {
|
||||
acc[index] = 0.0f;
|
||||
}
|
||||
uint const ColPerThreadA = (64u / 16u);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 374
|
||||
; Bound: 373
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
|
@ -127,7 +127,7 @@
|
|||
%_arr_float_uint_4 = OpTypeArray %float %uint_4
|
||||
%_ptr_Function__arr_float_uint_4 = OpTypePointer Function %_arr_float_uint_4
|
||||
%152 = OpConstantNull %_arr_float_uint_4
|
||||
%367 = OpTypeFunction %void
|
||||
%366 = OpTypeFunction %void
|
||||
%mm_readA = OpFunction %float None %24
|
||||
%row = OpFunctionParameter %uint
|
||||
%col = OpFunctionParameter %uint
|
||||
|
@ -287,334 +287,333 @@
|
|||
OpBranch %157
|
||||
%157 = OpLabel
|
||||
%159 = OpLoad %uint %index
|
||||
%160 = OpIMul %uint %uint_4 %uint_4
|
||||
%161 = OpULessThan %bool %159 %160
|
||||
%158 = OpLogicalNot %bool %161
|
||||
OpSelectionMerge %162 None
|
||||
OpBranchConditional %158 %163 %162
|
||||
%163 = OpLabel
|
||||
OpBranch %155
|
||||
%160 = OpULessThan %bool %159 %uint_16
|
||||
%158 = OpLogicalNot %bool %160
|
||||
OpSelectionMerge %161 None
|
||||
OpBranchConditional %158 %162 %161
|
||||
%162 = OpLabel
|
||||
%164 = OpLoad %uint %index
|
||||
%165 = OpAccessChain %_ptr_Function_float %acc %164
|
||||
OpStore %165 %51
|
||||
OpBranch %155
|
||||
%161 = OpLabel
|
||||
%163 = OpLoad %uint %index
|
||||
%164 = OpAccessChain %_ptr_Function_float %acc %163
|
||||
OpStore %164 %51
|
||||
OpBranch %156
|
||||
%156 = OpLabel
|
||||
%166 = OpLoad %uint %index
|
||||
%167 = OpIAdd %uint %166 %uint_1
|
||||
OpStore %index %167
|
||||
%165 = OpLoad %uint %index
|
||||
%166 = OpIAdd %uint %165 %uint_1
|
||||
OpStore %index %166
|
||||
OpBranch %154
|
||||
%155 = OpLabel
|
||||
%168 = OpUDiv %uint %uint_64 %uint_16
|
||||
%169 = OpCompositeExtract %uint %local_id 0
|
||||
%170 = OpIMul %uint %169 %168
|
||||
%171 = OpUDiv %uint %uint_64 %uint_16
|
||||
%172 = OpCompositeExtract %uint %local_id 1
|
||||
%173 = OpIMul %uint %172 %171
|
||||
%167 = OpUDiv %uint %uint_64 %uint_16
|
||||
%168 = OpCompositeExtract %uint %local_id 0
|
||||
%169 = OpIMul %uint %168 %167
|
||||
%170 = OpUDiv %uint %uint_64 %uint_16
|
||||
%171 = OpCompositeExtract %uint %local_id 1
|
||||
%172 = OpIMul %uint %171 %170
|
||||
OpStore %t %105
|
||||
OpBranch %175
|
||||
%175 = OpLabel
|
||||
OpLoopMerge %176 %177 None
|
||||
OpBranch %178
|
||||
%178 = OpLabel
|
||||
%180 = OpLoad %uint %t
|
||||
%181 = OpULessThan %bool %180 %141
|
||||
%179 = OpLogicalNot %bool %181
|
||||
OpSelectionMerge %182 None
|
||||
OpBranchConditional %179 %183 %182
|
||||
%183 = OpLabel
|
||||
OpBranch %176
|
||||
%182 = OpLabel
|
||||
OpStore %innerRow %105
|
||||
OpBranch %185
|
||||
%185 = OpLabel
|
||||
OpLoopMerge %186 %187 None
|
||||
OpBranch %188
|
||||
%188 = OpLabel
|
||||
%190 = OpLoad %uint %innerRow
|
||||
%191 = OpULessThan %bool %190 %uint_4
|
||||
%189 = OpLogicalNot %bool %191
|
||||
OpSelectionMerge %192 None
|
||||
OpBranchConditional %189 %193 %192
|
||||
%193 = OpLabel
|
||||
OpBranch %186
|
||||
%192 = OpLabel
|
||||
OpStore %innerCol %105
|
||||
OpBranch %195
|
||||
%195 = OpLabel
|
||||
OpLoopMerge %196 %197 None
|
||||
OpBranch %198
|
||||
%198 = OpLabel
|
||||
%200 = OpLoad %uint %innerCol
|
||||
%201 = OpULessThan %bool %200 %168
|
||||
%199 = OpLogicalNot %bool %201
|
||||
OpSelectionMerge %202 None
|
||||
OpBranchConditional %199 %203 %202
|
||||
%203 = OpLabel
|
||||
OpBranch %196
|
||||
%202 = OpLabel
|
||||
%204 = OpLoad %uint %innerRow
|
||||
%205 = OpIAdd %uint %130 %204
|
||||
%206 = OpLoad %uint %innerCol
|
||||
%207 = OpIAdd %uint %170 %206
|
||||
%209 = OpLoad %uint %innerRow
|
||||
%210 = OpIAdd %uint %134 %209
|
||||
%211 = OpLoad %uint %t
|
||||
%212 = OpIMul %uint %211 %uint_64
|
||||
%213 = OpIAdd %uint %212 %207
|
||||
%208 = OpFunctionCall %float %mm_readA %210 %213
|
||||
%214 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %205 %207
|
||||
OpStore %214 %208
|
||||
OpBranch %197
|
||||
%197 = OpLabel
|
||||
%215 = OpLoad %uint %innerCol
|
||||
%216 = OpIAdd %uint %215 %uint_1
|
||||
OpStore %innerCol %216
|
||||
OpBranch %195
|
||||
%196 = OpLabel
|
||||
OpBranch %187
|
||||
%187 = OpLabel
|
||||
%217 = OpLoad %uint %innerRow
|
||||
%218 = OpIAdd %uint %217 %uint_1
|
||||
OpStore %innerRow %218
|
||||
OpBranch %185
|
||||
%186 = OpLabel
|
||||
OpStore %innerRow_0 %105
|
||||
OpBranch %220
|
||||
%220 = OpLabel
|
||||
OpLoopMerge %221 %222 None
|
||||
OpBranch %223
|
||||
%223 = OpLabel
|
||||
%225 = OpLoad %uint %innerRow_0
|
||||
%226 = OpULessThan %bool %225 %171
|
||||
%224 = OpLogicalNot %bool %226
|
||||
OpSelectionMerge %227 None
|
||||
OpBranchConditional %224 %228 %227
|
||||
%228 = OpLabel
|
||||
OpBranch %221
|
||||
%227 = OpLabel
|
||||
OpStore %innerCol_0 %105
|
||||
OpBranch %230
|
||||
%230 = OpLabel
|
||||
OpLoopMerge %231 %232 None
|
||||
OpBranch %233
|
||||
%233 = OpLabel
|
||||
%235 = OpLoad %uint %innerCol_0
|
||||
%236 = OpULessThan %bool %235 %uint_4
|
||||
%234 = OpLogicalNot %bool %236
|
||||
OpSelectionMerge %237 None
|
||||
OpBranchConditional %234 %238 %237
|
||||
%238 = OpLabel
|
||||
OpBranch %231
|
||||
%237 = OpLabel
|
||||
%239 = OpLoad %uint %innerRow_0
|
||||
%240 = OpIAdd %uint %173 %239
|
||||
%241 = OpLoad %uint %innerCol_0
|
||||
%242 = OpIAdd %uint %132 %241
|
||||
%244 = OpLoad %uint %t
|
||||
%245 = OpIMul %uint %244 %uint_64
|
||||
%246 = OpIAdd %uint %245 %240
|
||||
%247 = OpLoad %uint %innerCol_0
|
||||
%248 = OpIAdd %uint %136 %247
|
||||
%243 = OpFunctionCall %float %mm_readB %246 %248
|
||||
%249 = OpLoad %uint %innerCol_0
|
||||
%250 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %249 %242
|
||||
OpStore %250 %243
|
||||
OpBranch %232
|
||||
%232 = OpLabel
|
||||
%251 = OpLoad %uint %innerCol_0
|
||||
%252 = OpIAdd %uint %251 %uint_1
|
||||
OpStore %innerCol_0 %252
|
||||
OpBranch %230
|
||||
%231 = OpLabel
|
||||
OpBranch %222
|
||||
%222 = OpLabel
|
||||
%253 = OpLoad %uint %innerRow_0
|
||||
%254 = OpIAdd %uint %253 %uint_1
|
||||
OpStore %innerRow_0 %254
|
||||
OpBranch %220
|
||||
%221 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpStore %k %105
|
||||
OpBranch %257
|
||||
%257 = OpLabel
|
||||
OpLoopMerge %258 %259 None
|
||||
OpBranch %260
|
||||
%260 = OpLabel
|
||||
%262 = OpLoad %uint %k
|
||||
%263 = OpULessThan %bool %262 %uint_64
|
||||
%261 = OpLogicalNot %bool %263
|
||||
OpSelectionMerge %264 None
|
||||
OpBranchConditional %261 %265 %264
|
||||
%265 = OpLabel
|
||||
OpBranch %258
|
||||
%264 = OpLabel
|
||||
OpStore %inner %105
|
||||
OpBranch %267
|
||||
%267 = OpLabel
|
||||
OpLoopMerge %268 %269 None
|
||||
OpBranch %270
|
||||
%270 = OpLabel
|
||||
%272 = OpLoad %uint %inner
|
||||
%273 = OpULessThan %bool %272 %uint_4
|
||||
%271 = OpLogicalNot %bool %273
|
||||
OpSelectionMerge %274 None
|
||||
OpBranchConditional %271 %275 %274
|
||||
%275 = OpLabel
|
||||
OpBranch %268
|
||||
%274 = OpLabel
|
||||
%276 = OpLoad %uint %inner
|
||||
%277 = OpAccessChain %_ptr_Function_float %BCached %276
|
||||
%278 = OpLoad %uint %k
|
||||
%279 = OpLoad %uint %inner
|
||||
%280 = OpIAdd %uint %132 %279
|
||||
%281 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %278 %280
|
||||
%282 = OpLoad %float %281
|
||||
OpStore %277 %282
|
||||
OpBranch %269
|
||||
%269 = OpLabel
|
||||
%283 = OpLoad %uint %inner
|
||||
%284 = OpIAdd %uint %283 %uint_1
|
||||
OpStore %inner %284
|
||||
OpBranch %267
|
||||
%268 = OpLabel
|
||||
OpStore %innerRow_1 %105
|
||||
OpBranch %286
|
||||
%286 = OpLabel
|
||||
OpLoopMerge %287 %288 None
|
||||
OpBranch %289
|
||||
%289 = OpLabel
|
||||
%291 = OpLoad %uint %innerRow_1
|
||||
%292 = OpULessThan %bool %291 %uint_4
|
||||
%290 = OpLogicalNot %bool %292
|
||||
OpSelectionMerge %293 None
|
||||
OpBranchConditional %290 %294 %293
|
||||
%294 = OpLabel
|
||||
OpBranch %287
|
||||
%293 = OpLabel
|
||||
%295 = OpLoad %uint %innerRow_1
|
||||
%296 = OpIAdd %uint %130 %295
|
||||
%297 = OpLoad %uint %k
|
||||
%298 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %296 %297
|
||||
%299 = OpLoad %float %298
|
||||
OpStore %ACached %299
|
||||
OpStore %innerCol_1 %105
|
||||
OpBranch %301
|
||||
%301 = OpLabel
|
||||
OpLoopMerge %302 %303 None
|
||||
OpBranch %304
|
||||
%304 = OpLabel
|
||||
%306 = OpLoad %uint %innerCol_1
|
||||
%307 = OpULessThan %bool %306 %uint_4
|
||||
%305 = OpLogicalNot %bool %307
|
||||
OpSelectionMerge %308 None
|
||||
OpBranchConditional %305 %309 %308
|
||||
%309 = OpLabel
|
||||
OpBranch %302
|
||||
%308 = OpLabel
|
||||
%310 = OpLoad %uint %innerRow_1
|
||||
%311 = OpIMul %uint %310 %uint_4
|
||||
%312 = OpLoad %uint %innerCol_1
|
||||
%313 = OpIAdd %uint %311 %312
|
||||
%314 = OpAccessChain %_ptr_Function_float %acc %313
|
||||
%315 = OpAccessChain %_ptr_Function_float %acc %313
|
||||
%316 = OpLoad %float %315
|
||||
%317 = OpLoad %float %ACached
|
||||
%318 = OpLoad %uint %innerCol_1
|
||||
%319 = OpAccessChain %_ptr_Function_float %BCached %318
|
||||
%320 = OpLoad %float %319
|
||||
%321 = OpFMul %float %317 %320
|
||||
%322 = OpFAdd %float %316 %321
|
||||
OpStore %314 %322
|
||||
OpBranch %303
|
||||
%303 = OpLabel
|
||||
%323 = OpLoad %uint %innerCol_1
|
||||
%324 = OpIAdd %uint %323 %uint_1
|
||||
OpStore %innerCol_1 %324
|
||||
OpBranch %301
|
||||
%302 = OpLabel
|
||||
OpBranch %288
|
||||
%288 = OpLabel
|
||||
%325 = OpLoad %uint %innerRow_1
|
||||
%326 = OpIAdd %uint %325 %uint_1
|
||||
OpStore %innerRow_1 %326
|
||||
OpBranch %286
|
||||
%287 = OpLabel
|
||||
OpBranch %259
|
||||
%259 = OpLabel
|
||||
%327 = OpLoad %uint %k
|
||||
%328 = OpIAdd %uint %327 %uint_1
|
||||
OpStore %k %328
|
||||
OpBranch %257
|
||||
%258 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpBranch %174
|
||||
%174 = OpLabel
|
||||
OpLoopMerge %175 %176 None
|
||||
OpBranch %177
|
||||
%177 = OpLabel
|
||||
%330 = OpLoad %uint %t
|
||||
%331 = OpIAdd %uint %330 %uint_1
|
||||
OpStore %t %331
|
||||
%179 = OpLoad %uint %t
|
||||
%180 = OpULessThan %bool %179 %141
|
||||
%178 = OpLogicalNot %bool %180
|
||||
OpSelectionMerge %181 None
|
||||
OpBranchConditional %178 %182 %181
|
||||
%182 = OpLabel
|
||||
OpBranch %175
|
||||
%181 = OpLabel
|
||||
OpStore %innerRow %105
|
||||
OpBranch %184
|
||||
%184 = OpLabel
|
||||
OpLoopMerge %185 %186 None
|
||||
OpBranch %187
|
||||
%187 = OpLabel
|
||||
%189 = OpLoad %uint %innerRow
|
||||
%190 = OpULessThan %bool %189 %uint_4
|
||||
%188 = OpLogicalNot %bool %190
|
||||
OpSelectionMerge %191 None
|
||||
OpBranchConditional %188 %192 %191
|
||||
%192 = OpLabel
|
||||
OpBranch %185
|
||||
%191 = OpLabel
|
||||
OpStore %innerCol %105
|
||||
OpBranch %194
|
||||
%194 = OpLabel
|
||||
OpLoopMerge %195 %196 None
|
||||
OpBranch %197
|
||||
%197 = OpLabel
|
||||
%199 = OpLoad %uint %innerCol
|
||||
%200 = OpULessThan %bool %199 %167
|
||||
%198 = OpLogicalNot %bool %200
|
||||
OpSelectionMerge %201 None
|
||||
OpBranchConditional %198 %202 %201
|
||||
%202 = OpLabel
|
||||
OpBranch %195
|
||||
%201 = OpLabel
|
||||
%203 = OpLoad %uint %innerRow
|
||||
%204 = OpIAdd %uint %130 %203
|
||||
%205 = OpLoad %uint %innerCol
|
||||
%206 = OpIAdd %uint %169 %205
|
||||
%208 = OpLoad %uint %innerRow
|
||||
%209 = OpIAdd %uint %134 %208
|
||||
%210 = OpLoad %uint %t
|
||||
%211 = OpIMul %uint %210 %uint_64
|
||||
%212 = OpIAdd %uint %211 %206
|
||||
%207 = OpFunctionCall %float %mm_readA %209 %212
|
||||
%213 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %204 %206
|
||||
OpStore %213 %207
|
||||
OpBranch %196
|
||||
%196 = OpLabel
|
||||
%214 = OpLoad %uint %innerCol
|
||||
%215 = OpIAdd %uint %214 %uint_1
|
||||
OpStore %innerCol %215
|
||||
OpBranch %194
|
||||
%195 = OpLabel
|
||||
OpBranch %186
|
||||
%186 = OpLabel
|
||||
%216 = OpLoad %uint %innerRow
|
||||
%217 = OpIAdd %uint %216 %uint_1
|
||||
OpStore %innerRow %217
|
||||
OpBranch %184
|
||||
%185 = OpLabel
|
||||
OpStore %innerRow_0 %105
|
||||
OpBranch %219
|
||||
%219 = OpLabel
|
||||
OpLoopMerge %220 %221 None
|
||||
OpBranch %222
|
||||
%222 = OpLabel
|
||||
%224 = OpLoad %uint %innerRow_0
|
||||
%225 = OpULessThan %bool %224 %170
|
||||
%223 = OpLogicalNot %bool %225
|
||||
OpSelectionMerge %226 None
|
||||
OpBranchConditional %223 %227 %226
|
||||
%227 = OpLabel
|
||||
OpBranch %220
|
||||
%226 = OpLabel
|
||||
OpStore %innerCol_0 %105
|
||||
OpBranch %229
|
||||
%229 = OpLabel
|
||||
OpLoopMerge %230 %231 None
|
||||
OpBranch %232
|
||||
%232 = OpLabel
|
||||
%234 = OpLoad %uint %innerCol_0
|
||||
%235 = OpULessThan %bool %234 %uint_4
|
||||
%233 = OpLogicalNot %bool %235
|
||||
OpSelectionMerge %236 None
|
||||
OpBranchConditional %233 %237 %236
|
||||
%237 = OpLabel
|
||||
OpBranch %230
|
||||
%236 = OpLabel
|
||||
%238 = OpLoad %uint %innerRow_0
|
||||
%239 = OpIAdd %uint %172 %238
|
||||
%240 = OpLoad %uint %innerCol_0
|
||||
%241 = OpIAdd %uint %132 %240
|
||||
%243 = OpLoad %uint %t
|
||||
%244 = OpIMul %uint %243 %uint_64
|
||||
%245 = OpIAdd %uint %244 %239
|
||||
%246 = OpLoad %uint %innerCol_0
|
||||
%247 = OpIAdd %uint %136 %246
|
||||
%242 = OpFunctionCall %float %mm_readB %245 %247
|
||||
%248 = OpLoad %uint %innerCol_0
|
||||
%249 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %248 %241
|
||||
OpStore %249 %242
|
||||
OpBranch %231
|
||||
%231 = OpLabel
|
||||
%250 = OpLoad %uint %innerCol_0
|
||||
%251 = OpIAdd %uint %250 %uint_1
|
||||
OpStore %innerCol_0 %251
|
||||
OpBranch %229
|
||||
%230 = OpLabel
|
||||
OpBranch %221
|
||||
%221 = OpLabel
|
||||
%252 = OpLoad %uint %innerRow_0
|
||||
%253 = OpIAdd %uint %252 %uint_1
|
||||
OpStore %innerRow_0 %253
|
||||
OpBranch %219
|
||||
%220 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpStore %k %105
|
||||
OpBranch %256
|
||||
%256 = OpLabel
|
||||
OpLoopMerge %257 %258 None
|
||||
OpBranch %259
|
||||
%259 = OpLabel
|
||||
%261 = OpLoad %uint %k
|
||||
%262 = OpULessThan %bool %261 %uint_64
|
||||
%260 = OpLogicalNot %bool %262
|
||||
OpSelectionMerge %263 None
|
||||
OpBranchConditional %260 %264 %263
|
||||
%264 = OpLabel
|
||||
OpBranch %257
|
||||
%263 = OpLabel
|
||||
OpStore %inner %105
|
||||
OpBranch %266
|
||||
%266 = OpLabel
|
||||
OpLoopMerge %267 %268 None
|
||||
OpBranch %269
|
||||
%269 = OpLabel
|
||||
%271 = OpLoad %uint %inner
|
||||
%272 = OpULessThan %bool %271 %uint_4
|
||||
%270 = OpLogicalNot %bool %272
|
||||
OpSelectionMerge %273 None
|
||||
OpBranchConditional %270 %274 %273
|
||||
%274 = OpLabel
|
||||
OpBranch %267
|
||||
%273 = OpLabel
|
||||
%275 = OpLoad %uint %inner
|
||||
%276 = OpAccessChain %_ptr_Function_float %BCached %275
|
||||
%277 = OpLoad %uint %k
|
||||
%278 = OpLoad %uint %inner
|
||||
%279 = OpIAdd %uint %132 %278
|
||||
%280 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %277 %279
|
||||
%281 = OpLoad %float %280
|
||||
OpStore %276 %281
|
||||
OpBranch %268
|
||||
%268 = OpLabel
|
||||
%282 = OpLoad %uint %inner
|
||||
%283 = OpIAdd %uint %282 %uint_1
|
||||
OpStore %inner %283
|
||||
OpBranch %266
|
||||
%267 = OpLabel
|
||||
OpStore %innerRow_1 %105
|
||||
OpBranch %285
|
||||
%285 = OpLabel
|
||||
OpLoopMerge %286 %287 None
|
||||
OpBranch %288
|
||||
%288 = OpLabel
|
||||
%290 = OpLoad %uint %innerRow_1
|
||||
%291 = OpULessThan %bool %290 %uint_4
|
||||
%289 = OpLogicalNot %bool %291
|
||||
OpSelectionMerge %292 None
|
||||
OpBranchConditional %289 %293 %292
|
||||
%293 = OpLabel
|
||||
OpBranch %286
|
||||
%292 = OpLabel
|
||||
%294 = OpLoad %uint %innerRow_1
|
||||
%295 = OpIAdd %uint %130 %294
|
||||
%296 = OpLoad %uint %k
|
||||
%297 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %295 %296
|
||||
%298 = OpLoad %float %297
|
||||
OpStore %ACached %298
|
||||
OpStore %innerCol_1 %105
|
||||
OpBranch %300
|
||||
%300 = OpLabel
|
||||
OpLoopMerge %301 %302 None
|
||||
OpBranch %303
|
||||
%303 = OpLabel
|
||||
%305 = OpLoad %uint %innerCol_1
|
||||
%306 = OpULessThan %bool %305 %uint_4
|
||||
%304 = OpLogicalNot %bool %306
|
||||
OpSelectionMerge %307 None
|
||||
OpBranchConditional %304 %308 %307
|
||||
%308 = OpLabel
|
||||
OpBranch %301
|
||||
%307 = OpLabel
|
||||
%309 = OpLoad %uint %innerRow_1
|
||||
%310 = OpIMul %uint %309 %uint_4
|
||||
%311 = OpLoad %uint %innerCol_1
|
||||
%312 = OpIAdd %uint %310 %311
|
||||
%313 = OpAccessChain %_ptr_Function_float %acc %312
|
||||
%314 = OpAccessChain %_ptr_Function_float %acc %312
|
||||
%315 = OpLoad %float %314
|
||||
%316 = OpLoad %float %ACached
|
||||
%317 = OpLoad %uint %innerCol_1
|
||||
%318 = OpAccessChain %_ptr_Function_float %BCached %317
|
||||
%319 = OpLoad %float %318
|
||||
%320 = OpFMul %float %316 %319
|
||||
%321 = OpFAdd %float %315 %320
|
||||
OpStore %313 %321
|
||||
OpBranch %302
|
||||
%302 = OpLabel
|
||||
%322 = OpLoad %uint %innerCol_1
|
||||
%323 = OpIAdd %uint %322 %uint_1
|
||||
OpStore %innerCol_1 %323
|
||||
OpBranch %300
|
||||
%301 = OpLabel
|
||||
OpBranch %287
|
||||
%287 = OpLabel
|
||||
%324 = OpLoad %uint %innerRow_1
|
||||
%325 = OpIAdd %uint %324 %uint_1
|
||||
OpStore %innerRow_1 %325
|
||||
OpBranch %285
|
||||
%286 = OpLabel
|
||||
OpBranch %258
|
||||
%258 = OpLabel
|
||||
%326 = OpLoad %uint %k
|
||||
%327 = OpIAdd %uint %326 %uint_1
|
||||
OpStore %k %327
|
||||
OpBranch %256
|
||||
%257 = OpLabel
|
||||
OpControlBarrier %uint_2 %uint_2 %uint_264
|
||||
OpBranch %176
|
||||
%176 = OpLabel
|
||||
%329 = OpLoad %uint %t
|
||||
%330 = OpIAdd %uint %329 %uint_1
|
||||
OpStore %t %330
|
||||
OpBranch %174
|
||||
%175 = OpLabel
|
||||
OpStore %innerRow_2 %105
|
||||
OpBranch %333
|
||||
%333 = OpLabel
|
||||
OpLoopMerge %334 %335 None
|
||||
OpBranch %336
|
||||
%336 = OpLabel
|
||||
%338 = OpLoad %uint %innerRow_2
|
||||
%339 = OpULessThan %bool %338 %uint_4
|
||||
%337 = OpLogicalNot %bool %339
|
||||
OpSelectionMerge %340 None
|
||||
OpBranchConditional %337 %341 %340
|
||||
%341 = OpLabel
|
||||
OpBranch %334
|
||||
%340 = OpLabel
|
||||
OpStore %innerCol_2 %105
|
||||
OpBranch %343
|
||||
%343 = OpLabel
|
||||
OpLoopMerge %344 %345 None
|
||||
OpBranch %346
|
||||
%346 = OpLabel
|
||||
%348 = OpLoad %uint %innerCol_2
|
||||
%349 = OpULessThan %bool %348 %uint_4
|
||||
%347 = OpLogicalNot %bool %349
|
||||
OpSelectionMerge %350 None
|
||||
OpBranchConditional %347 %351 %350
|
||||
%351 = OpLabel
|
||||
OpBranch %344
|
||||
%350 = OpLabel
|
||||
%352 = OpLoad %uint %innerRow_2
|
||||
%353 = OpIMul %uint %352 %uint_4
|
||||
%354 = OpLoad %uint %innerCol_2
|
||||
%355 = OpIAdd %uint %353 %354
|
||||
%357 = OpLoad %uint %innerRow_2
|
||||
%358 = OpIAdd %uint %134 %357
|
||||
%359 = OpLoad %uint %innerCol_2
|
||||
%360 = OpIAdd %uint %136 %359
|
||||
%361 = OpAccessChain %_ptr_Function_float %acc %355
|
||||
%362 = OpLoad %float %361
|
||||
%356 = OpFunctionCall %void %mm_write %358 %360 %362
|
||||
OpBranch %345
|
||||
%345 = OpLabel
|
||||
%363 = OpLoad %uint %innerCol_2
|
||||
%364 = OpIAdd %uint %363 %uint_1
|
||||
OpStore %innerCol_2 %364
|
||||
OpBranch %343
|
||||
%344 = OpLabel
|
||||
OpBranch %332
|
||||
%332 = OpLabel
|
||||
OpLoopMerge %333 %334 None
|
||||
OpBranch %335
|
||||
%335 = OpLabel
|
||||
%365 = OpLoad %uint %innerRow_2
|
||||
%366 = OpIAdd %uint %365 %uint_1
|
||||
OpStore %innerRow_2 %366
|
||||
%337 = OpLoad %uint %innerRow_2
|
||||
%338 = OpULessThan %bool %337 %uint_4
|
||||
%336 = OpLogicalNot %bool %338
|
||||
OpSelectionMerge %339 None
|
||||
OpBranchConditional %336 %340 %339
|
||||
%340 = OpLabel
|
||||
OpBranch %333
|
||||
%339 = OpLabel
|
||||
OpStore %innerCol_2 %105
|
||||
OpBranch %342
|
||||
%342 = OpLabel
|
||||
OpLoopMerge %343 %344 None
|
||||
OpBranch %345
|
||||
%345 = OpLabel
|
||||
%347 = OpLoad %uint %innerCol_2
|
||||
%348 = OpULessThan %bool %347 %uint_4
|
||||
%346 = OpLogicalNot %bool %348
|
||||
OpSelectionMerge %349 None
|
||||
OpBranchConditional %346 %350 %349
|
||||
%350 = OpLabel
|
||||
OpBranch %343
|
||||
%349 = OpLabel
|
||||
%351 = OpLoad %uint %innerRow_2
|
||||
%352 = OpIMul %uint %351 %uint_4
|
||||
%353 = OpLoad %uint %innerCol_2
|
||||
%354 = OpIAdd %uint %352 %353
|
||||
%356 = OpLoad %uint %innerRow_2
|
||||
%357 = OpIAdd %uint %134 %356
|
||||
%358 = OpLoad %uint %innerCol_2
|
||||
%359 = OpIAdd %uint %136 %358
|
||||
%360 = OpAccessChain %_ptr_Function_float %acc %354
|
||||
%361 = OpLoad %float %360
|
||||
%355 = OpFunctionCall %void %mm_write %357 %359 %361
|
||||
OpBranch %344
|
||||
%344 = OpLabel
|
||||
%362 = OpLoad %uint %innerCol_2
|
||||
%363 = OpIAdd %uint %362 %uint_1
|
||||
OpStore %innerCol_2 %363
|
||||
OpBranch %342
|
||||
%343 = OpLabel
|
||||
OpBranch %334
|
||||
%334 = OpLabel
|
||||
%364 = OpLoad %uint %innerRow_2
|
||||
%365 = OpIAdd %uint %364 %uint_1
|
||||
OpStore %innerRow_2 %365
|
||||
OpBranch %332
|
||||
%333 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%main = OpFunction %void None %367
|
||||
%369 = OpLabel
|
||||
%371 = OpLoad %v3uint %local_id_1
|
||||
%372 = OpLoad %v3uint %global_id_1
|
||||
%373 = OpLoad %uint %local_invocation_index_1
|
||||
%370 = OpFunctionCall %void %main_inner %371 %372 %373
|
||||
%main = OpFunction %void None %366
|
||||
%368 = OpLabel
|
||||
%370 = OpLoad %v3uint %local_id_1
|
||||
%371 = OpLoad %v3uint %global_id_1
|
||||
%372 = OpLoad %uint %local_invocation_index_1
|
||||
%369 = OpFunctionCall %void %main_inner %370 %371 %372
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float16_t r = (float16_t(1.0h) * float16_t(2.0h));
|
||||
const float16_t r = float16_t(2.0h);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#extension GL_AMD_gpu_shader_half_float : require
|
||||
|
||||
void f() {
|
||||
float16_t r = (1.0hf * 2.0hf);
|
||||
float16_t r = 2.0hf;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float r = (1.0f * 2.0f);
|
||||
const float r = 2.0f;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const float r = (1.0f * 2.0f);
|
||||
const float r = 2.0f;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
float r = (1.0f * 2.0f);
|
||||
float r = 2.0f;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const int r = (1 * 2);
|
||||
const int r = 2;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const int r = (1 * 2);
|
||||
const int r = 2;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
int r = (1 * 2);
|
||||
int r = 2;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const uint r = (1u * 2u);
|
||||
const uint r = 2u;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[numthreads(1, 1, 1)]
|
||||
void f() {
|
||||
const uint r = (1u * 2u);
|
||||
const uint r = 2u;
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#version 310 es
|
||||
|
||||
void f() {
|
||||
uint r = (1u * 2u);
|
||||
uint r = 2u;
|
||||
}
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
float main() {
|
||||
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
|
||||
return (2.0f / 5.0f);
|
||||
}
|
||||
|
||||
[numthreads(2, 1, 1)]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
float main() {
|
||||
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
|
||||
return (2.0f / 5.0f);
|
||||
}
|
||||
|
||||
[numthreads(2, 1, 1)]
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
using namespace metal;
|
||||
float tint_symbol() {
|
||||
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
|
||||
return (2.0f / 5.0f);
|
||||
}
|
||||
|
||||
kernel void ep() {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 16
|
||||
; Bound: 12
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
|
@ -12,19 +12,15 @@
|
|||
%float = OpTypeFloat 32
|
||||
%1 = OpTypeFunction %float
|
||||
%float_2 = OpConstant %float 2
|
||||
%float_3 = OpConstant %float 3
|
||||
%float_4 = OpConstant %float 4
|
||||
%float_5 = OpConstant %float 5
|
||||
%void = OpTypeVoid
|
||||
%12 = OpTypeFunction %void
|
||||
%8 = OpTypeFunction %void
|
||||
%main = OpFunction %float None %1
|
||||
%4 = OpLabel
|
||||
%7 = OpFMul %float %float_2 %float_3
|
||||
%9 = OpFSub %float %7 %float_4
|
||||
%11 = OpFDiv %float %9 %float_5
|
||||
OpReturnValue %11
|
||||
%7 = OpFDiv %float %float_2 %float_5
|
||||
OpReturnValue %7
|
||||
OpFunctionEnd
|
||||
%ep = OpFunction %void None %12
|
||||
%15 = OpLabel
|
||||
%ep = OpFunction %void None %8
|
||||
%11 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
Loading…
Reference in New Issue