tint: Implement const eval of binary multiply

Bug: tint:1581
Change-Id: I70ff40ed4d8faf0a665824fef936ffbafb3f0948
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99362
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano 2022-09-01 14:57:39 +00:00 committed by Dawn LUCI CQ
parent ae6f76fe3a
commit c20c5dfb4a
35 changed files with 1243 additions and 482 deletions

View File

@ -898,15 +898,15 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
@const op - <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const op - <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op * <T: fiu32_f16>(T, T) -> T
op * <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op * <T: fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
op * <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op * <T: f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
op * <T: f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
op * <T: f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
op * <T: f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
op * <T: f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
@const("Multiply") op * <T: fia_fiu32_f16>(T, T) -> T
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (vec<N, T>, T) -> vec<N, T>
@const("Multiply") op * <T: fia_fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
@const("Multiply") op * <T: fa_f32_f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
@const("MultiplyMatVec") op * <T: fa_f32_f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
@const("MultiplyVecMat") op * <T: fa_f32_f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
@const("MultiplyMatMat") op * <T: fa_f32_f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
op / <T: fiu32_f16>(T, T) -> T
op / <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>

View File

@ -22,6 +22,7 @@
#include <optional>
#include <ostream>
#include "src/tint/traits.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/result.h"
@ -33,6 +34,14 @@ struct Number;
} // namespace tint
namespace tint::detail {
/// Base template for IsNumber: selected for every type that is not an
/// instantiation of Number<>, and derives from std::false_type.
template <typename T>
struct IsNumber : std::false_type {};
/// Specialization for IsNumber: selected when the type is exactly Number<T>
/// for some T, and derives from std::true_type.
template <typename T>
struct IsNumber<Number<T>> : std::true_type {};
/// An empty structure used as a unique template type for Number when
/// specializing for the f16 type.
struct NumberKindF16 {};
@ -68,6 +77,10 @@ constexpr bool IsInteger = std::is_integral_v<T>;
template <typename T>
constexpr bool IsNumeric = IsInteger<T> || IsFloatingPoint<T>;
/// Evaluates to true iff T is a Number, i.e. an instantiation of the Number<>
/// template as detected by detail::IsNumber.
template <typename T>
constexpr bool IsNumber = detail::IsNumber<T>::value;
/// Resolves to the underlying type for a Number.
template <typename T>
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
@ -236,6 +249,26 @@ using f32 = Number<float>;
/// However, since C++ doesn't have a native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>;
/// Maps a Number type to the name used for it in diagnostics.
/// Instantiating this with a Number type other than AInt, AFloat, i32, u32,
/// f32 or f16 is a compile-time error.
/// @returns the friendly name of Number type T
template <typename T, typename = traits::EnableIf<IsNumber<T>>>
const char* FriendlyName() {
    constexpr bool kIsAInt = std::is_same_v<T, AInt>;
    constexpr bool kIsAFloat = std::is_same_v<T, AFloat>;
    constexpr bool kIsI32 = std::is_same_v<T, i32>;
    constexpr bool kIsU32 = std::is_same_v<T, u32>;
    constexpr bool kIsF32 = std::is_same_v<T, f32>;
    constexpr bool kIsF16 = std::is_same_v<T, f16>;
    // Reject any Number type that does not have a name below.
    static_assert(kIsAInt || kIsAFloat || kIsI32 || kIsU32 || kIsF32 || kIsF16, "Unhandled type");

    if constexpr (kIsAInt) {
        return "abstract-int";
    } else if constexpr (kIsAFloat) {
        return "abstract-float";
    } else if constexpr (kIsI32) {
        return "i32";
    } else if constexpr (kIsU32) {
        return "u32";
    } else if constexpr (kIsF32) {
        return "f32";
    } else {
        // Only f16 can reach here; everything else was rejected by the static_assert.
        return "f16";
    }
}
/// Enumerator of failure reasons when converting from one number to another.
enum class ConversionFailure {
kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type
@ -437,6 +470,15 @@ inline std::optional<AInt> CheckedMul(AInt a, AInt b) {
return AInt(result);
}
/// Multiplies two abstract floats, rejecting any product that is not finite.
/// @param a the lhs value
/// @param b the rhs value
/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat
inline std::optional<AFloat> CheckedMul(AFloat a, AFloat b) {
    const auto product = a.value * b.value;
    // std::isfinite() is false for +/- infinity and for NaN.
    if (std::isfinite(product)) {
        return AFloat{product};
    }
    return std::nullopt;
}
/// @returns a * b + c, or an empty optional if the value overflowed the AInt
inline std::optional<AInt> CheckedMadd(AInt a, AInt b, AInt c) {
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635

View File

@ -38,6 +38,7 @@
#include "src/tint/sem/vector.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
@ -508,11 +509,202 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder,
auto* ty = n0 > n1 ? c0->Type() : c1->Type();
return CreateComposite(builder, ty, std::move(els));
}
} // namespace
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
/// Adds two numbers of the same Number type.
/// For AInt and AFloat the sum is range-checked with CheckedAdd(); on failure an
/// error is raised at `*current_source` and Failure is returned. For all other
/// Number types the sum is returned unchecked, with signed integers added as
/// unsigned so that overflow wraps instead of being undefined behaviour.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
    using T = UnwrapNumber<NumberT>;
    constexpr bool kIsAbstract = std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>;

    // Unchecked sum of the raw values.
    auto raw_sum = [](T x, T y) {
        if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
            // Ensure no UB for signed overflow
            using Unsigned = std::make_unsigned_t<T>;
            return static_cast<T>(static_cast<Unsigned>(x) + static_cast<Unsigned>(y));
        } else {
            return x + y;
        }
    };

    NumberT sum;
    if constexpr (kIsAbstract) {
        // Check for over/underflow for abstract values
        auto checked = CheckedAdd(a, b);
        if (!checked) {
            // The message shows the unchecked (wrapped / infinite) value.
            AddError("'" + std::to_string(raw_sum(a.value, b.value)) +
                         "' cannot be represented as '" + FriendlyName<NumberT>() + "'",
                     *current_source);
            return utils::Failure;
        }
        sum = checked->value;
    } else {
        sum = raw_sum(a.value, b.value);
    }
    return sum;
}
/// Multiplies two numbers of the same Number type.
/// For AInt and AFloat the product is range-checked with CheckedMul(); on failure
/// an error is raised at `*current_source` and Failure is returned. For all other
/// Number types the product is returned unchecked, with signed integers
/// multiplied as unsigned so that overflow wraps instead of being undefined
/// behaviour.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
    using T = UnwrapNumber<NumberT>;
    constexpr bool kIsAbstract = std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>;

    // Unchecked product of the raw values.
    auto raw_product = [](T x, T y) {
        if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
            // For signed integrals, avoid C++ UB by multiplying as unsigned
            using Unsigned = std::make_unsigned_t<T>;
            return static_cast<T>(static_cast<Unsigned>(x) * static_cast<Unsigned>(y));
        } else {
            return x * y;
        }
    };

    NumberT product;
    if constexpr (kIsAbstract) {
        // Check for over/underflow for abstract values
        auto checked = CheckedMul(a, b);
        if (!checked) {
            // The message shows the unchecked (wrapped / infinite) value.
            AddError("'" + std::to_string(raw_product(a.value, b.value)) +
                         "' cannot be represented as '" + FriendlyName<NumberT>() + "'",
                     *current_source);
            return utils::Failure;
        }
        product = checked->value;
    } else {
        product = raw_product(a.value, b.value);
    }
    return product;
}
/// Computes a1 * b1 + a2 * b2 using Mul() and Add(), stopping at the first
/// operation that fails.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) {
    auto p1 = Mul(a1, b1);
    if (!p1) {
        return utils::Failure;
    }
    auto p2 = Mul(a2, b2);
    if (!p2) {
        return utils::Failure;
    }
    // Add() already yields Failure on error, so its result is returned directly.
    return Add(p1.Get(), p2.Get());
}
/// Computes a1 * b1 + a2 * b2 + a3 * b3 using Mul() and Add(). All three
/// products are formed first, then summed left to right; the first operation
/// that fails ends the computation.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot3(NumberT a1,
                                       NumberT a2,
                                       NumberT a3,
                                       NumberT b1,
                                       NumberT b2,
                                       NumberT b3) {
    auto p1 = Mul(a1, b1);
    if (!p1) {
        return utils::Failure;
    }
    auto p2 = Mul(a2, b2);
    if (!p2) {
        return utils::Failure;
    }
    auto p3 = Mul(a3, b3);
    if (!p3) {
        return utils::Failure;
    }
    auto partial = Add(p1.Get(), p2.Get());
    if (!partial) {
        return utils::Failure;
    }
    // Add() already yields Failure on error, so its result is returned directly.
    return Add(partial.Get(), p3.Get());
}
/// Computes a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4 using Mul() and Add(). All
/// four products are formed first, then summed left to right; the first
/// operation that fails ends the computation.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
                                       NumberT a2,
                                       NumberT a3,
                                       NumberT a4,
                                       NumberT b1,
                                       NumberT b2,
                                       NumberT b3,
                                       NumberT b4) {
    auto p1 = Mul(a1, b1);
    if (!p1) {
        return utils::Failure;
    }
    auto p2 = Mul(a2, b2);
    if (!p2) {
        return utils::Failure;
    }
    auto p3 = Mul(a3, b3);
    if (!p3) {
        return utils::Failure;
    }
    auto p4 = Mul(a4, b4);
    if (!p4) {
        return utils::Failure;
    }
    auto partial = Add(p1.Get(), p2.Get());
    if (!partial) {
        return utils::Failure;
    }
    partial = Add(partial.Get(), p3.Get());
    if (!partial) {
        return utils::Failure;
    }
    // Add() already yields Failure on error, so its result is returned directly.
    return Add(partial.Get(), p4.Get());
}
/// Builds a callable that adds two numbers with Add() and wraps a successful
/// result in a Constant of type `elem_ty`. The callable captures `this`.
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
    return [this, elem_ty](auto lhs, auto rhs) -> utils::Result<const Constant*> {
        auto sum = Add(lhs, rhs);
        if (!sum) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, sum.Get());
    };
}
/// Builds a callable that multiplies two numbers with Mul() and wraps a
/// successful result in a Constant of type `elem_ty`. The callable captures `this`.
auto ConstEval::MulFunc(const sem::Type* elem_ty) {
    return [this, elem_ty](auto lhs, auto rhs) -> utils::Result<const Constant*> {
        auto product = Mul(lhs, rhs);
        if (!product) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, product.Get());
    };
}
/// Builds a callable that evaluates Dot2() and wraps a successful result in a
/// Constant of type `elem_ty`. The callable captures `this`.
auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> {
        auto dot = Dot2(a1, a2, b1, b2);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
/// Builds a callable that evaluates Dot3() and wraps a successful result in a
/// Constant of type `elem_ty`. The callable captures `this`.
auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto a3, auto b1, auto b2,
                           auto b3) -> utils::Result<const Constant*> {
        auto dot = Dot3(a1, a2, a3, b1, b2, b3);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
/// Builds a callable that evaluates Dot4() and wraps a successful result in a
/// Constant of type `elem_ty`. The callable captures `this`.
auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
    return [this, elem_ty](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3,
                           auto b4) -> utils::Result<const Constant*> {
        auto dot = Dot4(a1, a2, a3, a4, b1, b2, b3, b4);
        if (!dot) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, dot.Get());
    };
}
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) {
return Switch(
@ -756,42 +948,15 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*,
return TransformElements(builder, transform, args[0]);
}
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty,
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* {
using NumberT = decltype(i);
using T = UnwrapNumber<NumberT>;
auto add_values = [](T lhs, T rhs) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// Ensure no UB for signed overflow
using UT = std::make_unsigned_t<T>;
return static_cast<T>(static_cast<UT>(lhs) + static_cast<UT>(rhs));
} else {
return lhs + rhs;
TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* {
if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) {
return r.Get();
}
};
NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
// Check for over/underflow for abstract values
if (auto r = CheckedAdd(i, j)) {
result = r->value;
} else {
AddError("'" + std::to_string(add_values(i.value, j.value)) +
"' cannot be represented as '" +
ty->FriendlyName(builder.Symbols()) + "'",
source);
return nullptr;
}
} else {
result = add_values(i.value, j.value);
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_fia_fiu32_f16(create, c0, c1);
};
auto r = TransformBinaryElements(builder, transform, args[0], args[1]);
@ -846,6 +1011,192 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
return r;
}
/// Element-wise '*' of args[0] and args[1], applied through
/// TransformBinaryElements(). Errors raised by Mul() are reported at `source`.
ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/,
                                                utils::VectorRef<const sem::Constant*> args,
                                                const Source& source) {
    // Make `source` visible to Mul() for the duration of this call.
    TINT_SCOPED_ASSIGNMENT(current_source, &source);

    auto multiply_pair = [&](const sem::Constant* lhs,
                             const sem::Constant* rhs) -> const Constant* {
        auto product = Dispatch_fia_fiu32_f16(MulFunc(lhs->Type()), lhs, rhs);
        if (!product) {
            return nullptr;
        }
        return product.Get();
    };

    auto result = TransformBinaryElements(builder, multiply_pair, args[0], args[1]);
    // A failed element only yields nullptr above, so failure is detected via the diagnostics.
    if (builder.Diagnostics().contains_errors()) {
        return utils::Failure;
    }
    return result;
}
/// Evaluates matrix * vector: component i of the result is the dot product of
/// matrix row i with the vector. Returns Failure as soon as one dot product fails.
ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    // Make `source` visible to Mul() / Add() for the duration of this call.
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
    auto* vec_ty = args[1]->Type()->As<sem::Vector>();
    auto* elem_ty = vec_ty->type();

    // Dot product of row `row` of matrix `m` with vector `v`.
    auto row_dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
        // The matrix is indexed column first, then row.
        auto m_at = [&](size_t col) { return m->Index(col)->Index(row); };
        utils::Result<const Constant*> dot_result;
        switch (mat_ty->columns()) {
            case 2:
                dot_result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),       //
                                                 m_at(0), m_at(1),        //
                                                 v->Index(0), v->Index(1));
                break;
            case 3:
                dot_result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),          //
                                                 m_at(0), m_at(1), m_at(2),  //
                                                 v->Index(0), v->Index(1), v->Index(2));
                break;
            case 4:
                dot_result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),                   //
                                                 m_at(0), m_at(1), m_at(2), m_at(3),  //
                                                 v->Index(0), v->Index(1), v->Index(2),
                                                 v->Index(3));
                break;
        }
        return dot_result;
    };

    utils::Vector<const sem::Constant*, 4> components;
    for (size_t row = 0; row < mat_ty->rows(); ++row) {
        auto component = row_dot(args[0], row, args[1]);  // matrix row * vector
        if (!component) {
            return utils::Failure;
        }
        components.Push(component.Get());
    }
    return CreateComposite(builder, ty, components);
}
/// Evaluates vector * matrix: component i of the result is the dot product of
/// the vector with matrix column i. Returns Failure as soon as one dot product
/// fails.
/// @param ty the type of the resulting vector
/// @param args args[0] is the vector, args[1] is the matrix
/// @param source the source location used when reporting errors
ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    // Make `source` visible to Mul() / Add() for the duration of this call.
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* vec_ty = args[0]->Type()->As<sem::Vector>();
    auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
    auto* elem_ty = vec_ty->type();

    // Dot product of vector `v` with column `col` of matrix `m`.
    auto col_dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
        // The matrix is indexed column first, then row.
        auto m_at = [&](size_t row) { return m->Index(col)->Index(row); };
        utils::Result<const Constant*> dot_result;
        switch (mat_ty->rows()) {
            case 2:
                dot_result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),       //
                                                 m_at(0), m_at(1),        //
                                                 v->Index(0), v->Index(1));
                break;
            case 3:
                dot_result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),          //
                                                 m_at(0), m_at(1), m_at(2),  //
                                                 v->Index(0), v->Index(1), v->Index(2));
                break;
            case 4:
                dot_result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),                   //
                                                 m_at(0), m_at(1), m_at(2), m_at(3),  //
                                                 v->Index(0), v->Index(1), v->Index(2),
                                                 v->Index(3));
                // Explicit break, matching the sibling switches, so that a case
                // appended later cannot be fallen into.
                break;
        }
        return dot_result;
    };

    utils::Vector<const sem::Constant*, 4> components;
    for (size_t col = 0; col < mat_ty->columns(); ++col) {
        auto component = col_dot(args[0], args[1], col);  // vector * matrix column
        if (!component) {
            return utils::Failure;
        }
        components.Push(component.Get());
    }
    return CreateComposite(builder, ty, components);
}
/// Evaluates matrix * matrix: element (row r, column c) of the result is the dot
/// product of row r of args[0] with column c of args[1]. The result is built
/// column by column; Failure is returned as soon as one dot product fails.
ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
                                                      utils::VectorRef<const sem::Constant*> args,
                                                      const Source& source) {
    // Make `source` visible to Mul() / Add() for the duration of this call.
    TINT_SCOPED_ASSIGNMENT(current_source, &source);
    auto* lhs = args[0];
    auto* rhs = args[1];
    auto* lhs_ty = lhs->Type()->As<sem::Matrix>();
    auto* rhs_ty = rhs->Type()->As<sem::Matrix>();
    auto* elem_ty = lhs_ty->type();

    // Dot product of row `row` of `m1` with column `col` of `m2`.
    auto row_dot_col = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2,
                           size_t col) {
        // Matrices are indexed column first, then row.
        auto lhs_at = [&](size_t k) { return m1->Index(k)->Index(row); };
        auto rhs_at = [&](size_t k) { return m2->Index(col)->Index(k); };
        utils::Result<const Constant*> dot_result;
        switch (lhs_ty->columns()) {
            case 2:
                dot_result = Dispatch_fa_f32_f16(Dot2Func(elem_ty),     //
                                                 lhs_at(0), lhs_at(1),  //
                                                 rhs_at(0), rhs_at(1));
                break;
            case 3:
                dot_result = Dispatch_fa_f32_f16(Dot3Func(elem_ty),                //
                                                 lhs_at(0), lhs_at(1), lhs_at(2),  //
                                                 rhs_at(0), rhs_at(1), rhs_at(2));
                break;
            case 4:
                dot_result = Dispatch_fa_f32_f16(Dot4Func(elem_ty),                           //
                                                 lhs_at(0), lhs_at(1), lhs_at(2), lhs_at(3),  //
                                                 rhs_at(0), rhs_at(1), rhs_at(2), rhs_at(3));
                break;
        }
        return dot_result;
    };

    utils::Vector<const sem::Constant*, 4> columns;
    for (size_t col = 0; col < rhs_ty->columns(); ++col) {
        utils::Vector<const sem::Constant*, 4> column;
        for (size_t row = 0; row < lhs_ty->rows(); ++row) {
            auto element = row_dot_col(lhs, row, rhs, col);  // lhs row * rhs column
            if (!element) {
                return utils::Failure;
            }
            column.Push(element.Get());
        }
        // Add column vector to matrix
        auto* column_ty = ty->As<sem::Matrix>()->ColumnType();
        columns.Push(CreateComposite(builder, column_ty, column));
    }
    return CreateComposite(builder, ty, columns);
}
ConstEval::ConstantResult ConstEval::atan2(const sem::Type*,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@ -230,6 +230,42 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Element-wise multiply operator '*'. Used for the scalar * scalar,
/// vector * vector, scalar * vector, vector * scalar, scalar * matrix and
/// matrix * scalar overloads.
/// @param ty the expression type
/// @param args the input arguments (lhs, rhs)
/// @param source the source location used when reporting errors
/// @return the result value, or Failure if the value cannot be calculated
ConstantResult OpMultiply(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for matCxR<T> * vecC<T>
/// Each component of the result is the dot product of a matrix row with the vector.
/// @param ty the expression type
/// @param args the input arguments (matrix, vector)
/// @param source the source location used when reporting errors
/// @return the result value, or Failure if the value cannot be calculated
ConstantResult OpMultiplyMatVec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for vecR<T> * matCxR<T>
/// Each component of the result is the dot product of the vector with a matrix column.
/// @param ty the expression type
/// @param args the input arguments (vector, matrix)
/// @param source the source location used when reporting errors
/// @return the result value, or Failure if the value cannot be calculated
ConstantResult OpMultiplyVecMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Multiply operator '*' for matKxR<T> * matCxK<T>
/// Each element of the result is the dot product of a row of the lhs matrix with
/// a column of the rhs matrix.
/// @param ty the expression type
/// @param args the input arguments (lhs matrix, rhs matrix)
/// @param source the source location used when reporting errors
/// @return the result value, or Failure if the value cannot be calculated
ConstantResult OpMultiplyMatMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////
@ -259,7 +295,97 @@ class ConstEval {
/// Adds the given warning message to the diagnostics
void AddWarning(const std::string& msg, const Source& source) const;
/// Adds two Number<T>s
/// For AInt and AFloat, a sum that cannot be represented is an error. For other
/// Number types the sum is not range-checked.
/// @param a the lhs number
/// @param b the rhs number
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Add(NumberT a, NumberT b);
/// Multiplies two Number<T>s
/// For AInt and AFloat, a product that cannot be represented is an error. For
/// other Number types the product is not range-checked.
/// @param a the lhs number
/// @param b the rhs number
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Mul(NumberT a, NumberT b);
/// Returns the dot product of (a1,a2) with (b1,b2)
/// Evaluated with Mul() and Add(), so an error from either is propagated.
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2);
/// Returns the dot product of (a1,a2,a3) with (b1,b2,b3)
/// Evaluated with Mul() and Add(), so an error from either is propagated.
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param a3 component 3 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @param b3 component 3 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot3(NumberT a1,
NumberT a2,
NumberT a3,
NumberT b1,
NumberT b2,
NumberT b3);
/// Returns the dot product of (a1,a2,a3,a4) with (b1,b2,b3,b4)
/// Evaluated with Mul() and Add(), so an error from either is propagated.
/// @param a1 component 1 of lhs vector
/// @param a2 component 2 of lhs vector
/// @param a3 component 3 of lhs vector
/// @param a4 component 4 of lhs vector
/// @param b1 component 1 of rhs vector
/// @param b2 component 2 of rhs vector
/// @param b3 component 3 of rhs vector
/// @param b4 component 4 of rhs vector
/// @returns the result number on success, or logs an error and returns Failure
template <typename NumberT>
utils::Result<NumberT> Dot4(NumberT a1,
NumberT a2,
NumberT a3,
NumberT a4,
NumberT b1,
NumberT b2,
NumberT b3,
NumberT b4);
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// The callable takes the two operands and returns a utils::Result<const Constant*>.
/// As the return type is deduced, the definition must be visible at the point of use.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto AddFunc(const sem::Type* elem_ty);
/// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// The callable takes the two operands and returns a utils::Result<const Constant*>.
/// As the return type is deduced, the definition must be visible at the point of use.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto MulFunc(const sem::Type* elem_ty);
/// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// The callable takes (a1, a2, b1, b2) and returns a utils::Result<const Constant*>.
/// As the return type is deduced, the definition must be visible at the point of use.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot2Func(const sem::Type* elem_ty);
/// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// The callable takes (a1, a2, a3, b1, b2, b3) and returns a utils::Result<const Constant*>.
/// As the return type is deduced, the definition must be visible at the point of use.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot3Func(const sem::Type* elem_ty);
/// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// The callable takes (a1, a2, a3, a4, b1, b2, b3, b4) and returns a
/// utils::Result<const Constant*>.
/// As the return type is deduced, the definition must be visible at the point of use.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto Dot4Func(const sem::Type* elem_ty);
/// The program builder used to create constants and to report diagnostics.
ProgramBuilder& builder;
/// The source location used by Add() and Mul() when reporting errors. The
/// operator entry points assign it for the duration of their call via
/// TINT_SCOPED_ASSIGNMENT.
const Source* current_source = nullptr;
};
} // namespace tint::resolver

View File

@ -15,6 +15,7 @@
#include <cmath>
#include <type_traits>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/builtin_type.h"
@ -24,6 +25,8 @@
#include "src/tint/sem/test_helper.h"
#include "src/tint/utils/transform.h"
using ::testing::HasSubstr;
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
@ -74,6 +77,19 @@ auto Abs(const Number<T>& v) {
}
}
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
/// Test helper that computes the expected product of two Numbers.
/// For signed integral T the operands are multiplied as unsigned values, so an
/// overflowing product wraps instead of invoking undefined behaviour.
/// @param v1 the lhs number
/// @param v2 the rhs number
/// @returns v1 * v2 as a Number<T>
template <typename T>
constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
// For signed integrals, avoid C++ UB by multiplying as unsigned
using UT = std::make_unsigned_t<T>;
return static_cast<Number<T>>(static_cast<UT>(v1) * static_cast<UT>(v2));
} else {
return static_cast<Number<T>>(v1 * v2);
}
}
TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW);
// Concats any number of std::vectors
template <typename Vec, typename... Vecs>
[[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) {
@ -3228,8 +3244,7 @@ Case C(T lhs, U rhs, V expected, bool overflow = false) {
return Case{Val(lhs), Val(rhs), Val(expected), overflow};
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
auto print_value = [&](auto&& value) {
static std::ostream& operator<<(std::ostream& o, const Types& types) {
std::visit(
[&](auto&& v) {
using ValueType = std::decay_t<decltype(v)>;
@ -3242,15 +3257,13 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
}
o << ")";
},
value);
};
o << "lhs: ";
print_value(c.lhs);
o << ", rhs: ";
print_value(c.rhs);
o << ", expected: ";
print_value(c.expected);
o << ", overflow: " << c.overflow;
types);
return o;
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected
<< ", overflow: " << c.overflow;
return o;
}
@ -3281,7 +3294,6 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) {
using ResolverConstEvalBinaryOpTest = ResolverTestWithParam<std::tuple<ast::BinaryOp, Case>>;
TEST_P(ResolverConstEvalBinaryOpTest, Test) {
Enable(ast::Extension::kF16);
auto op = std::get<0>(GetParam());
auto& c = std::get<1>(GetParam());
@ -3300,10 +3312,8 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
GlobalConst("C", expr);
auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
@ -3413,6 +3423,215 @@ INSTANTIATE_TEST_SUITE_P(Sub,
OpSubFloatCases<f32>(),
OpSubFloatCases<f16>()))));
/// @returns the scalar * scalar test cases for Number type T: simple products,
/// multiplication of the Highest / Lowest values by one, and Highest * Highest and
/// Lowest * Lowest, which are flagged as overflowing.
template <typename T>
std::vector<Case> OpMulScalarCases() {
return {
C(T{0}, T{0}, T{0}),
C(T{1}, T{2}, T{2}),
C(T{2}, T{3}, T{6}),
C(Negate(T{2}), T{3}, Negate(T{6})),
C(T::Highest(), T{1}, T::Highest()),
C(T::Lowest(), T{1}, T::Lowest()),
// Expected values are computed with the wrapping Mul() test helper.
C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true),
C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true),
};
}
/// @returns the test cases for multiplications involving a vec3 of Number type T:
/// scalar * vector, vector * scalar and vector * vector.
// NOTE(review): this is also instantiated with integer types (AInt, i32, u32);
// presumably the fractional literals are truncated on construction - confirm.
template <typename T>
std::vector<Case> OpMulVecCases() {
return {
// s * vec3 = vec3
C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})),
// vec3 * s = vec3
C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
// vec3 * vec3 = vec3
C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})),
};
}
/// @returns the test cases for multiplications involving matrices of Number type T:
/// scalar * matrix, matrix * scalar, vector * matrix, matrix * vector and
/// matrix * matrix.
template <typename T>
std::vector<Case> OpMulMatCases() {
return {
// s * mat3x2 = mat3x2
C(Val(T{2.25}),
Mat({T{1.0}, T{4.0}}, //
{T{2.0}, T{5.0}}, //
{T{3.0}, T{6.0}}),
Mat({T{2.25}, T{9.0}}, //
{T{4.5}, T{11.25}}, //
{T{6.75}, T{13.5}})),
// mat3x2 * s = mat3x2
C(Mat({T{1.0}, T{4.0}}, //
{T{2.0}, T{5.0}}, //
{T{3.0}, T{6.0}}),
Val(T{2.25}),
Mat({T{2.25}, T{9.0}}, //
{T{4.5}, T{11.25}}, //
{T{6.75}, T{13.5}})),
// vec3 * mat2x3 = vec2
C(Vec(T{1.25}, T{2.25}, T{3.25}), //
Mat({T{1.0}, T{2.0}, T{3.0}}, //
{T{4.0}, T{5.0}, T{6.0}}), //
Vec(T{15.5}, T{35.75})),
// mat2x3 * vec2 = vec3
C(Mat({T{1.0}, T{2.0}, T{3.0}}, //
{T{4.0}, T{5.0}, T{6.0}}), //
Vec(T{1.25}, T{2.25}), //
Vec(T{10.25}, T{13.75}, T{17.25})),
// mat3x2 * mat2x3 = mat2x2
C(Mat({T{1.0}, T{2.0}}, //
{T{3.0}, T{4.0}}, //
{T{5.0}, T{6.0}}), //
Mat({T{1.25}, T{2.25}, T{3.25}}, //
{T{4.25}, T{5.25}, T{6.25}}), //
Mat({T{24.25}, T{31.0}}, //
{T{51.25}, T{67.0}})), //
};
}
INSTANTIATE_TEST_SUITE_P(Mul,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kMultiply),
testing::ValuesIn(Concat( //
OpMulScalarCases<AInt>(),
OpMulScalarCases<i32>(),
OpMulScalarCases<u32>(),
OpMulScalarCases<AFloat>(),
OpMulScalarCases<f32>(),
OpMulScalarCases<f16>(),
OpMulVecCases<AInt>(),
OpMulVecCases<i32>(),
OpMulVecCases<u32>(),
OpMulVecCases<AFloat>(),
OpMulVecCases<f32>(),
OpMulVecCases<f16>(),
OpMulMatCases<AFloat>(),
OpMulMatCases<f32>(),
OpMulMatCases<f16>()))));
// Tests for errors on overflow/underflow of binary operations with abstract numbers
/// A single overflow test case: a binary expression that is expected to fail to
/// resolve.
struct OverflowCase {
/// The binary operator under test
ast::BinaryOp op;
/// The left-hand side operand
Types lhs;
/// The right-hand side operand
Types rhs;
/// The text of the overflowed value that the error message is expected to quote
std::string overflowed_result;
};
/// Prints an OverflowCase as its operator name followed by both operands.
/// @param o the stream to write to
/// @param c the case to print
/// @returns the stream `o`
static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) {
    o << ast::FriendlyName(c.op);
    o << ", lhs: " << c.lhs;
    o << ", rhs: " << c.rhs;
    return o;
}
using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCase>;
// Builds `const C = lhs <op> rhs;`, expects resolving to fail, and checks that
// the error names the overflowed value and the friendly name of the lhs
// element type.
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
    Enable(ast::Extension::kF16);
    auto& param = GetParam();

    // Builds the AST expression for whichever alternative an operand holds.
    auto to_expr = [&](auto&& value) { return value.Expr(*this); };
    auto* lhs_expr = std::visit(to_expr, param.lhs);
    auto* rhs_expr = std::visit(to_expr, param.rhs);

    auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, param.op, lhs_expr, rhs_expr);
    GlobalConst("C", expr);
    ASSERT_FALSE(r()->Resolve());

    // The error message names the element type of the lhs operand.
    auto element_type_name = [&](auto&& value) {
        using ValueType = std::decay_t<decltype(value)>;
        return tint::FriendlyName<typename ValueType::ElementType>();
    };
    std::string type_name = std::visit(element_type_name, param.lhs);

    EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + param.overflowed_result +
                                        "' cannot be represented as '" + type_name + "'"));
}
INSTANTIATE_TEST_SUITE_P(
Test,
ResolverConstEvalBinaryOpTest_Overflow,
testing::Values( //
// scalar-scalar add
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"},
OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"},
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"},
OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"},
// scalar-scalar subtract
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a),
"9223372036854775807"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a),
"-9223372036854775808"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()),
"inf"},
OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()),
"-inf"},
// scalar-scalar multiply
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"},
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"},
// scalar-vector multiply
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"},
OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"},
// vector-matrix multiply
// Overflow from first multiplication of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(AFloat::Highest(), 1.0_a), //
Mat({2.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from second multiplication of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(1.0_a, AFloat::Highest()), //
Mat({1.0_a, 2.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from addition of dot product of vector and matrix column 0
// i.e. (v[0] * m[0][0] + v[1] * m[0][1])
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Vec(AFloat::Highest(), AFloat::Highest()), //
Mat({1.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// matrix-matrix multiply
// Overflow from first multiplication of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({AFloat::Highest(), 1.0_a}, //
{1.0_a, 1.0_a}), //
Mat({2.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"},
// Overflow from second multiplication of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({1.0_a, AFloat::Highest()}, //
{1.0_a, 1.0_a}), //
Mat({1.0_a, 1.0_a}, //
{2.0_a, 1.0_a}), //
"inf"},
// Overflow from addition of dot product of lhs row 0 and rhs column 0
// i.e. m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]
// ^
OverflowCase{ast::BinaryOp::kMultiply, //
Mat({AFloat::Highest(), 1.0_a}, //
{AFloat::Highest(), 1.0_a}), //
Mat({1.0_a, 1.0_a}, //
{1.0_a, 1.0_a}), //
"inf"}
));
TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) {
GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a));
EXPECT_FALSE(r()->Resolve());

View File

@ -9572,108 +9572,108 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[727],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [118] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[725],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [119] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[723],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [120] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[15],
/* template types */ &kTemplateTypes[13],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[721],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [121] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[719],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [122] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[6],
/* parameters */ &kParameters[717],
/* return matcher indices */ &kMatcherIndices[10],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiply,
},
{
/* [123] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[715],
/* return matcher indices */ &kMatcherIndices[69],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyMatVec,
},
{
/* [124] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 2,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[1],
/* parameters */ &kParameters[713],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyVecMat,
},
{
/* [125] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 3,
/* template types */ &kTemplateTypes[11],
/* template types */ &kTemplateTypes[12],
/* template numbers */ &kTemplateNumbers[0],
/* parameters */ &kParameters[711],
/* return matcher indices */ &kMatcherIndices[22],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpMultiplyMatMat,
},
{
/* [126] */
@ -14630,15 +14630,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
},
{
/* [2] */
/* op *<T : fiu32_f16>(T, T) -> T */
/* op *<T : fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op *<T : fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op *<T : fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op *<T : f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
/* op *<T : f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
/* op *<T : f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
/* op *<T : f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
/* op *<T : f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* op *<T : fia_fiu32_f16>(T, T) -> T */
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, vec<N, T>) -> vec<N, T> */
/* op *<T : fia_fiu32_f16, N : num>(vec<N, T>, T) -> vec<N, T> */
/* op *<T : fia_fiu32_f16, N : num>(T, vec<N, T>) -> vec<N, T> */
/* op *<T : fa_f32_f16, N : num, M : num>(T, mat<N, M, T>) -> mat<N, M, T> */
/* op *<T : fa_f32_f16, N : num, M : num>(mat<N, M, T>, T) -> mat<N, M, T> */
/* op *<T : fa_f32_f16, C : num, R : num>(mat<C, R, T>, vec<C, T>) -> vec<R, T> */
/* op *<T : fa_f32_f16, C : num, R : num>(vec<R, T>, mat<C, R, T>) -> vec<C, T> */
/* op *<T : fa_f32_f16, K : num, C : num, R : num>(mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T> */
/* num overloads */ 9,
/* overloads */ &kOverloads[117],
},

View File

@ -641,15 +641,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
9 candidate operators:
operator * (T, T) -> T where: T is f32, i32, u32 or f16
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
operator * (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}
@ -673,15 +673,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
9 candidate operators:
operator *= (T, T) -> T where: T is f32, i32, u32 or f16
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32, u32 or f16
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (vecN<T>, T) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (T, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (T, matNxM<T>) -> matNxM<T> where: T is abstract-float, f32 or f16
operator *= (matNxM<T>, T) -> matNxM<T> where: T is abstract-float, f32 or f16
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is abstract-float, abstract-int, f32, i32, u32 or f16
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is abstract-float, f32 or f16
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is abstract-float, f32 or f16
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is abstract-float, f32 or f16
)");
}

View File

@ -17,6 +17,7 @@
#include <ostream>
#include <variant>
#include "src/tint/debug.h"
namespace tint::utils {
@ -36,6 +37,9 @@ struct [[nodiscard]] Result {
static_assert(!std::is_same_v<SUCCESS_TYPE, FAILURE_TYPE>,
"Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE");
/// Default constructor initializes to invalid state
Result() : value(std::monostate{}) {}
/// Constructor
/// @param success the success result
Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit):
@ -47,27 +51,43 @@ struct [[nodiscard]] Result {
: value{failure} {}
/// @returns true if the result was a success
operator bool() const { return std::holds_alternative<SUCCESS_TYPE>(value); }
operator bool() const {
Validate();
return std::holds_alternative<SUCCESS_TYPE>(value);
}
/// @returns true if the result was a failure
bool operator!() const { return std::holds_alternative<FAILURE_TYPE>(value); }
bool operator!() const {
Validate();
return std::holds_alternative<FAILURE_TYPE>(value);
}
/// @returns the success value
/// @warning attempting to call this when the Result holds a failure will result in UB.
const SUCCESS_TYPE* operator->() const { return &std::get<SUCCESS_TYPE>(value); }
const SUCCESS_TYPE* operator->() const {
Validate();
return &(Get());
}
/// @returns the success value
/// @warning attempting to call this when the Result holds a failure value will result in UB.
const SUCCESS_TYPE& Get() const { return std::get<SUCCESS_TYPE>(value); }
const SUCCESS_TYPE& Get() const {
Validate();
return std::get<SUCCESS_TYPE>(value);
}
/// @returns the failure value
/// @warning attempting to call this when the Result holds a success value will result in UB.
const FAILURE_TYPE& Failure() const { return std::get<FAILURE_TYPE>(value); }
const FAILURE_TYPE& Failure() const {
Validate();
return std::get<FAILURE_TYPE>(value);
}
/// Equality operator
/// @param val the value to compare this Result to
/// @returns true if this result holds a success value equal to `val`
bool operator==(SUCCESS_TYPE val) const {
Validate();
if (auto* v = std::get_if<SUCCESS_TYPE>(&value)) {
return *v == val;
}
@ -78,14 +98,18 @@ struct [[nodiscard]] Result {
/// @param val the value to compare this Result to
/// @returns true if this result holds a failure value equal to `val`
bool operator==(FAILURE_TYPE val) const {
Validate();
if (auto* v = std::get_if<FAILURE_TYPE>(&value)) {
return *v == val;
}
return false;
}
private:
/// Asserts that this Result is not in the `std::monostate` state left by the default
/// constructor, i.e. that it holds a success or failure value.
void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative<std::monostate>(value)); }
/// The result. Either a success or failure value.
std::variant<SUCCESS_TYPE, FAILURE_TYPE> value;
std::variant<std::monostate, SUCCESS_TYPE, FAILURE_TYPE> value;
};
/// Writes the result to the ostream.

View File

@ -151,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
auto* lhs = Expr("a");
auto* rhs = Expr(1_f);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@ -162,13 +163,14 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
EXPECT_EQ(out.str(), "(a * 1.0f)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
Enable(ast::Extension::kF16);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
auto* lhs = Expr("a");
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@ -179,12 +181,13 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
EXPECT_EQ(out.str(), "(a * 1.0hf)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
GlobalVar("a", vec3<f32>(1_f, 1_f, 1_f), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_f);
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@ -194,14 +197,15 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
EXPECT_EQ(out.str(), "(1.0f * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
Enable(ast::Extension::kF16);
GlobalVar("a", vec3<f16>(1_h, 1_h, 1_h), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_h);
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr("a");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
@ -211,7 +215,7 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
EXPECT_EQ(out.str(), "(1.0hf * a)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {

View File

@ -184,7 +184,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
@ -201,7 +201,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
@ -216,7 +216,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
EXPECT_EQ(out.str(), "(1.0f).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
@ -233,7 +233,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {

View File

@ -10,7 +10,7 @@ struct tint_symbol_1 {
};
void main_inner(uint3 GlobalInvocationID) {
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
flatIndex = (flatIndex * 1u);
float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0));
{

View File

@ -10,7 +10,7 @@ struct tint_symbol_1 {
};
void main_inner(uint3 GlobalInvocationID) {
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
flatIndex = (flatIndex * 1u);
float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0));
{

View File

@ -9,7 +9,7 @@ layout(binding = 3, std430) buffer Result_1 {
} result;
uniform highp sampler2DArray myTexture_1;
void tint_symbol(uvec3 GlobalInvocationID) {
uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x);
flatIndex = (flatIndex * 1u);
vec4 texel = texelFetch(myTexture_1, ivec3(ivec2(GlobalInvocationID.xy), 0), 0);
{

View File

@ -23,7 +23,7 @@ struct Result {
};
void tint_symbol_inner(uint3 GlobalInvocationID, texture2d_array<float, access::sample> tint_symbol_1, device Result* const tint_symbol_2) {
uint flatIndex = ((((2u * 2u) * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]);
uint flatIndex = (((4u * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]);
flatIndex = (flatIndex * 1u);
float4 texel = tint_symbol_1.read(uint2(int2(uint3(GlobalInvocationID).xy)), 0, 0);
for(uint i = 0u; (i < 1u); i = (i + 1u)) {

View File

@ -52,6 +52,7 @@
%result = OpVariable %_ptr_StorageBuffer_Result StorageBuffer
%void = OpTypeVoid
%17 = OpTypeFunction %void %v3uint
%uint_4 = OpConstant %uint 4
%uint_2 = OpConstant %uint 2
%_ptr_Function_uint = OpTypePointer Function %uint
%33 = OpConstantNull %uint
@ -74,12 +75,11 @@
%flatIndex = OpVariable %_ptr_Function_uint Function %33
%texel = OpVariable %_ptr_Function_v4float Function %51
%i = OpVariable %_ptr_Function_uint Function %33
%23 = OpIMul %uint %uint_2 %uint_2
%24 = OpCompositeExtract %uint %GlobalInvocationID 2
%25 = OpIMul %uint %23 %24
%23 = OpCompositeExtract %uint %GlobalInvocationID 2
%24 = OpIMul %uint %uint_4 %23
%26 = OpCompositeExtract %uint %GlobalInvocationID 1
%27 = OpIMul %uint %uint_2 %26
%28 = OpIAdd %uint %25 %27
%28 = OpIAdd %uint %24 %27
%29 = OpCompositeExtract %uint %GlobalInvocationID 0
%30 = OpIAdd %uint %28 %29
OpStore %flatIndex %30

View File

@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) {
float ACached = 0.0f;
float BCached[4] = (float[4])0;
{
[loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
[loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) {
acc[index] = 0.0f;
}
}

View File

@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) {
float ACached = 0.0f;
float BCached[4] = (float[4])0;
{
[loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
[loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) {
acc[index] = 0.0f;
}
}

View File

@ -77,7 +77,7 @@ void tint_symbol(uvec3 local_id, uvec3 global_id, uint local_invocation_index) {
float ACached = 0.0f;
float BCached[4] = float[4](0.0f, 0.0f, 0.0f, 0.0f);
{
for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
for(uint index = 0u; (index < 16u); index = (index + 1u)) {
acc[index] = 0.0f;
}
}

View File

@ -63,7 +63,7 @@ void tint_symbol_inner(uint3 local_id, uint3 global_id, uint local_invocation_in
tint_array<float, 16> acc = {};
float ACached = 0.0f;
tint_array<float, 4> BCached = {};
for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) {
for(uint index = 0u; (index < 16u); index = (index + 1u)) {
acc[index] = 0.0f;
}
uint const ColPerThreadA = (64u / 16u);

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 374
; Bound: 373
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
@ -127,7 +127,7 @@
%_arr_float_uint_4 = OpTypeArray %float %uint_4
%_ptr_Function__arr_float_uint_4 = OpTypePointer Function %_arr_float_uint_4
%152 = OpConstantNull %_arr_float_uint_4
%367 = OpTypeFunction %void
%366 = OpTypeFunction %void
%mm_readA = OpFunction %float None %24
%row = OpFunctionParameter %uint
%col = OpFunctionParameter %uint
@ -287,334 +287,333 @@
OpBranch %157
%157 = OpLabel
%159 = OpLoad %uint %index
%160 = OpIMul %uint %uint_4 %uint_4
%161 = OpULessThan %bool %159 %160
%158 = OpLogicalNot %bool %161
OpSelectionMerge %162 None
OpBranchConditional %158 %163 %162
%163 = OpLabel
OpBranch %155
%160 = OpULessThan %bool %159 %uint_16
%158 = OpLogicalNot %bool %160
OpSelectionMerge %161 None
OpBranchConditional %158 %162 %161
%162 = OpLabel
%164 = OpLoad %uint %index
%165 = OpAccessChain %_ptr_Function_float %acc %164
OpStore %165 %51
OpBranch %155
%161 = OpLabel
%163 = OpLoad %uint %index
%164 = OpAccessChain %_ptr_Function_float %acc %163
OpStore %164 %51
OpBranch %156
%156 = OpLabel
%166 = OpLoad %uint %index
%167 = OpIAdd %uint %166 %uint_1
OpStore %index %167
%165 = OpLoad %uint %index
%166 = OpIAdd %uint %165 %uint_1
OpStore %index %166
OpBranch %154
%155 = OpLabel
%168 = OpUDiv %uint %uint_64 %uint_16
%169 = OpCompositeExtract %uint %local_id 0
%170 = OpIMul %uint %169 %168
%171 = OpUDiv %uint %uint_64 %uint_16
%172 = OpCompositeExtract %uint %local_id 1
%173 = OpIMul %uint %172 %171
%167 = OpUDiv %uint %uint_64 %uint_16
%168 = OpCompositeExtract %uint %local_id 0
%169 = OpIMul %uint %168 %167
%170 = OpUDiv %uint %uint_64 %uint_16
%171 = OpCompositeExtract %uint %local_id 1
%172 = OpIMul %uint %171 %170
OpStore %t %105
OpBranch %175
%175 = OpLabel
OpLoopMerge %176 %177 None
OpBranch %178
%178 = OpLabel
%180 = OpLoad %uint %t
%181 = OpULessThan %bool %180 %141
%179 = OpLogicalNot %bool %181
OpSelectionMerge %182 None
OpBranchConditional %179 %183 %182
%183 = OpLabel
OpBranch %176
%182 = OpLabel
OpStore %innerRow %105
OpBranch %185
%185 = OpLabel
OpLoopMerge %186 %187 None
OpBranch %188
%188 = OpLabel
%190 = OpLoad %uint %innerRow
%191 = OpULessThan %bool %190 %uint_4
%189 = OpLogicalNot %bool %191
OpSelectionMerge %192 None
OpBranchConditional %189 %193 %192
%193 = OpLabel
OpBranch %186
%192 = OpLabel
OpStore %innerCol %105
OpBranch %195
%195 = OpLabel
OpLoopMerge %196 %197 None
OpBranch %198
%198 = OpLabel
%200 = OpLoad %uint %innerCol
%201 = OpULessThan %bool %200 %168
%199 = OpLogicalNot %bool %201
OpSelectionMerge %202 None
OpBranchConditional %199 %203 %202
%203 = OpLabel
OpBranch %196
%202 = OpLabel
%204 = OpLoad %uint %innerRow
%205 = OpIAdd %uint %130 %204
%206 = OpLoad %uint %innerCol
%207 = OpIAdd %uint %170 %206
%209 = OpLoad %uint %innerRow
%210 = OpIAdd %uint %134 %209
%211 = OpLoad %uint %t
%212 = OpIMul %uint %211 %uint_64
%213 = OpIAdd %uint %212 %207
%208 = OpFunctionCall %float %mm_readA %210 %213
%214 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %205 %207
OpStore %214 %208
OpBranch %197
%197 = OpLabel
%215 = OpLoad %uint %innerCol
%216 = OpIAdd %uint %215 %uint_1
OpStore %innerCol %216
OpBranch %195
%196 = OpLabel
OpBranch %187
%187 = OpLabel
%217 = OpLoad %uint %innerRow
%218 = OpIAdd %uint %217 %uint_1
OpStore %innerRow %218
OpBranch %185
%186 = OpLabel
OpStore %innerRow_0 %105
OpBranch %220
%220 = OpLabel
OpLoopMerge %221 %222 None
OpBranch %223
%223 = OpLabel
%225 = OpLoad %uint %innerRow_0
%226 = OpULessThan %bool %225 %171
%224 = OpLogicalNot %bool %226
OpSelectionMerge %227 None
OpBranchConditional %224 %228 %227
%228 = OpLabel
OpBranch %221
%227 = OpLabel
OpStore %innerCol_0 %105
OpBranch %230
%230 = OpLabel
OpLoopMerge %231 %232 None
OpBranch %233
%233 = OpLabel
%235 = OpLoad %uint %innerCol_0
%236 = OpULessThan %bool %235 %uint_4
%234 = OpLogicalNot %bool %236
OpSelectionMerge %237 None
OpBranchConditional %234 %238 %237
%238 = OpLabel
OpBranch %231
%237 = OpLabel
%239 = OpLoad %uint %innerRow_0
%240 = OpIAdd %uint %173 %239
%241 = OpLoad %uint %innerCol_0
%242 = OpIAdd %uint %132 %241
%244 = OpLoad %uint %t
%245 = OpIMul %uint %244 %uint_64
%246 = OpIAdd %uint %245 %240
%247 = OpLoad %uint %innerCol_0
%248 = OpIAdd %uint %136 %247
%243 = OpFunctionCall %float %mm_readB %246 %248
%249 = OpLoad %uint %innerCol_0
%250 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %249 %242
OpStore %250 %243
OpBranch %232
%232 = OpLabel
%251 = OpLoad %uint %innerCol_0
%252 = OpIAdd %uint %251 %uint_1
OpStore %innerCol_0 %252
OpBranch %230
%231 = OpLabel
OpBranch %222
%222 = OpLabel
%253 = OpLoad %uint %innerRow_0
%254 = OpIAdd %uint %253 %uint_1
OpStore %innerRow_0 %254
OpBranch %220
%221 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
OpStore %k %105
OpBranch %257
%257 = OpLabel
OpLoopMerge %258 %259 None
OpBranch %260
%260 = OpLabel
%262 = OpLoad %uint %k
%263 = OpULessThan %bool %262 %uint_64
%261 = OpLogicalNot %bool %263
OpSelectionMerge %264 None
OpBranchConditional %261 %265 %264
%265 = OpLabel
OpBranch %258
%264 = OpLabel
OpStore %inner %105
OpBranch %267
%267 = OpLabel
OpLoopMerge %268 %269 None
OpBranch %270
%270 = OpLabel
%272 = OpLoad %uint %inner
%273 = OpULessThan %bool %272 %uint_4
%271 = OpLogicalNot %bool %273
OpSelectionMerge %274 None
OpBranchConditional %271 %275 %274
%275 = OpLabel
OpBranch %268
%274 = OpLabel
%276 = OpLoad %uint %inner
%277 = OpAccessChain %_ptr_Function_float %BCached %276
%278 = OpLoad %uint %k
%279 = OpLoad %uint %inner
%280 = OpIAdd %uint %132 %279
%281 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %278 %280
%282 = OpLoad %float %281
OpStore %277 %282
OpBranch %269
%269 = OpLabel
%283 = OpLoad %uint %inner
%284 = OpIAdd %uint %283 %uint_1
OpStore %inner %284
OpBranch %267
%268 = OpLabel
OpStore %innerRow_1 %105
OpBranch %286
%286 = OpLabel
OpLoopMerge %287 %288 None
OpBranch %289
%289 = OpLabel
%291 = OpLoad %uint %innerRow_1
%292 = OpULessThan %bool %291 %uint_4
%290 = OpLogicalNot %bool %292
OpSelectionMerge %293 None
OpBranchConditional %290 %294 %293
%294 = OpLabel
OpBranch %287
%293 = OpLabel
%295 = OpLoad %uint %innerRow_1
%296 = OpIAdd %uint %130 %295
%297 = OpLoad %uint %k
%298 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %296 %297
%299 = OpLoad %float %298
OpStore %ACached %299
OpStore %innerCol_1 %105
OpBranch %301
%301 = OpLabel
OpLoopMerge %302 %303 None
OpBranch %304
%304 = OpLabel
%306 = OpLoad %uint %innerCol_1
%307 = OpULessThan %bool %306 %uint_4
%305 = OpLogicalNot %bool %307
OpSelectionMerge %308 None
OpBranchConditional %305 %309 %308
%309 = OpLabel
OpBranch %302
%308 = OpLabel
%310 = OpLoad %uint %innerRow_1
%311 = OpIMul %uint %310 %uint_4
%312 = OpLoad %uint %innerCol_1
%313 = OpIAdd %uint %311 %312
%314 = OpAccessChain %_ptr_Function_float %acc %313
%315 = OpAccessChain %_ptr_Function_float %acc %313
%316 = OpLoad %float %315
%317 = OpLoad %float %ACached
%318 = OpLoad %uint %innerCol_1
%319 = OpAccessChain %_ptr_Function_float %BCached %318
%320 = OpLoad %float %319
%321 = OpFMul %float %317 %320
%322 = OpFAdd %float %316 %321
OpStore %314 %322
OpBranch %303
%303 = OpLabel
%323 = OpLoad %uint %innerCol_1
%324 = OpIAdd %uint %323 %uint_1
OpStore %innerCol_1 %324
OpBranch %301
%302 = OpLabel
OpBranch %288
%288 = OpLabel
%325 = OpLoad %uint %innerRow_1
%326 = OpIAdd %uint %325 %uint_1
OpStore %innerRow_1 %326
OpBranch %286
%287 = OpLabel
OpBranch %259
%259 = OpLabel
%327 = OpLoad %uint %k
%328 = OpIAdd %uint %327 %uint_1
OpStore %k %328
OpBranch %257
%258 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
OpBranch %174
%174 = OpLabel
OpLoopMerge %175 %176 None
OpBranch %177
%177 = OpLabel
%330 = OpLoad %uint %t
%331 = OpIAdd %uint %330 %uint_1
OpStore %t %331
%179 = OpLoad %uint %t
%180 = OpULessThan %bool %179 %141
%178 = OpLogicalNot %bool %180
OpSelectionMerge %181 None
OpBranchConditional %178 %182 %181
%182 = OpLabel
OpBranch %175
%181 = OpLabel
OpStore %innerRow %105
OpBranch %184
%184 = OpLabel
OpLoopMerge %185 %186 None
OpBranch %187
%187 = OpLabel
%189 = OpLoad %uint %innerRow
%190 = OpULessThan %bool %189 %uint_4
%188 = OpLogicalNot %bool %190
OpSelectionMerge %191 None
OpBranchConditional %188 %192 %191
%192 = OpLabel
OpBranch %185
%191 = OpLabel
OpStore %innerCol %105
OpBranch %194
%194 = OpLabel
OpLoopMerge %195 %196 None
OpBranch %197
%197 = OpLabel
%199 = OpLoad %uint %innerCol
%200 = OpULessThan %bool %199 %167
%198 = OpLogicalNot %bool %200
OpSelectionMerge %201 None
OpBranchConditional %198 %202 %201
%202 = OpLabel
OpBranch %195
%201 = OpLabel
%203 = OpLoad %uint %innerRow
%204 = OpIAdd %uint %130 %203
%205 = OpLoad %uint %innerCol
%206 = OpIAdd %uint %169 %205
%208 = OpLoad %uint %innerRow
%209 = OpIAdd %uint %134 %208
%210 = OpLoad %uint %t
%211 = OpIMul %uint %210 %uint_64
%212 = OpIAdd %uint %211 %206
%207 = OpFunctionCall %float %mm_readA %209 %212
%213 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %204 %206
OpStore %213 %207
OpBranch %196
%196 = OpLabel
%214 = OpLoad %uint %innerCol
%215 = OpIAdd %uint %214 %uint_1
OpStore %innerCol %215
OpBranch %194
%195 = OpLabel
OpBranch %186
%186 = OpLabel
%216 = OpLoad %uint %innerRow
%217 = OpIAdd %uint %216 %uint_1
OpStore %innerRow %217
OpBranch %184
%185 = OpLabel
OpStore %innerRow_0 %105
OpBranch %219
%219 = OpLabel
OpLoopMerge %220 %221 None
OpBranch %222
%222 = OpLabel
%224 = OpLoad %uint %innerRow_0
%225 = OpULessThan %bool %224 %170
%223 = OpLogicalNot %bool %225
OpSelectionMerge %226 None
OpBranchConditional %223 %227 %226
%227 = OpLabel
OpBranch %220
%226 = OpLabel
OpStore %innerCol_0 %105
OpBranch %229
%229 = OpLabel
OpLoopMerge %230 %231 None
OpBranch %232
%232 = OpLabel
%234 = OpLoad %uint %innerCol_0
%235 = OpULessThan %bool %234 %uint_4
%233 = OpLogicalNot %bool %235
OpSelectionMerge %236 None
OpBranchConditional %233 %237 %236
%237 = OpLabel
OpBranch %230
%236 = OpLabel
%238 = OpLoad %uint %innerRow_0
%239 = OpIAdd %uint %172 %238
%240 = OpLoad %uint %innerCol_0
%241 = OpIAdd %uint %132 %240
%243 = OpLoad %uint %t
%244 = OpIMul %uint %243 %uint_64
%245 = OpIAdd %uint %244 %239
%246 = OpLoad %uint %innerCol_0
%247 = OpIAdd %uint %136 %246
%242 = OpFunctionCall %float %mm_readB %245 %247
%248 = OpLoad %uint %innerCol_0
%249 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %248 %241
OpStore %249 %242
OpBranch %231
%231 = OpLabel
%250 = OpLoad %uint %innerCol_0
%251 = OpIAdd %uint %250 %uint_1
OpStore %innerCol_0 %251
OpBranch %229
%230 = OpLabel
OpBranch %221
%221 = OpLabel
%252 = OpLoad %uint %innerRow_0
%253 = OpIAdd %uint %252 %uint_1
OpStore %innerRow_0 %253
OpBranch %219
%220 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
OpStore %k %105
OpBranch %256
%256 = OpLabel
OpLoopMerge %257 %258 None
OpBranch %259
%259 = OpLabel
%261 = OpLoad %uint %k
%262 = OpULessThan %bool %261 %uint_64
%260 = OpLogicalNot %bool %262
OpSelectionMerge %263 None
OpBranchConditional %260 %264 %263
%264 = OpLabel
OpBranch %257
%263 = OpLabel
OpStore %inner %105
OpBranch %266
%266 = OpLabel
OpLoopMerge %267 %268 None
OpBranch %269
%269 = OpLabel
%271 = OpLoad %uint %inner
%272 = OpULessThan %bool %271 %uint_4
%270 = OpLogicalNot %bool %272
OpSelectionMerge %273 None
OpBranchConditional %270 %274 %273
%274 = OpLabel
OpBranch %267
%273 = OpLabel
%275 = OpLoad %uint %inner
%276 = OpAccessChain %_ptr_Function_float %BCached %275
%277 = OpLoad %uint %k
%278 = OpLoad %uint %inner
%279 = OpIAdd %uint %132 %278
%280 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %277 %279
%281 = OpLoad %float %280
OpStore %276 %281
OpBranch %268
%268 = OpLabel
%282 = OpLoad %uint %inner
%283 = OpIAdd %uint %282 %uint_1
OpStore %inner %283
OpBranch %266
%267 = OpLabel
OpStore %innerRow_1 %105
OpBranch %285
%285 = OpLabel
OpLoopMerge %286 %287 None
OpBranch %288
%288 = OpLabel
%290 = OpLoad %uint %innerRow_1
%291 = OpULessThan %bool %290 %uint_4
%289 = OpLogicalNot %bool %291
OpSelectionMerge %292 None
OpBranchConditional %289 %293 %292
%293 = OpLabel
OpBranch %286
%292 = OpLabel
%294 = OpLoad %uint %innerRow_1
%295 = OpIAdd %uint %130 %294
%296 = OpLoad %uint %k
%297 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %295 %296
%298 = OpLoad %float %297
OpStore %ACached %298
OpStore %innerCol_1 %105
OpBranch %300
%300 = OpLabel
OpLoopMerge %301 %302 None
OpBranch %303
%303 = OpLabel
%305 = OpLoad %uint %innerCol_1
%306 = OpULessThan %bool %305 %uint_4
%304 = OpLogicalNot %bool %306
OpSelectionMerge %307 None
OpBranchConditional %304 %308 %307
%308 = OpLabel
OpBranch %301
%307 = OpLabel
%309 = OpLoad %uint %innerRow_1
%310 = OpIMul %uint %309 %uint_4
%311 = OpLoad %uint %innerCol_1
%312 = OpIAdd %uint %310 %311
%313 = OpAccessChain %_ptr_Function_float %acc %312
%314 = OpAccessChain %_ptr_Function_float %acc %312
%315 = OpLoad %float %314
%316 = OpLoad %float %ACached
%317 = OpLoad %uint %innerCol_1
%318 = OpAccessChain %_ptr_Function_float %BCached %317
%319 = OpLoad %float %318
%320 = OpFMul %float %316 %319
%321 = OpFAdd %float %315 %320
OpStore %313 %321
OpBranch %302
%302 = OpLabel
%322 = OpLoad %uint %innerCol_1
%323 = OpIAdd %uint %322 %uint_1
OpStore %innerCol_1 %323
OpBranch %300
%301 = OpLabel
OpBranch %287
%287 = OpLabel
%324 = OpLoad %uint %innerRow_1
%325 = OpIAdd %uint %324 %uint_1
OpStore %innerRow_1 %325
OpBranch %285
%286 = OpLabel
OpBranch %258
%258 = OpLabel
%326 = OpLoad %uint %k
%327 = OpIAdd %uint %326 %uint_1
OpStore %k %327
OpBranch %256
%257 = OpLabel
OpControlBarrier %uint_2 %uint_2 %uint_264
OpBranch %176
%176 = OpLabel
%329 = OpLoad %uint %t
%330 = OpIAdd %uint %329 %uint_1
OpStore %t %330
OpBranch %174
%175 = OpLabel
OpStore %innerRow_2 %105
OpBranch %333
%333 = OpLabel
OpLoopMerge %334 %335 None
OpBranch %336
%336 = OpLabel
%338 = OpLoad %uint %innerRow_2
%339 = OpULessThan %bool %338 %uint_4
%337 = OpLogicalNot %bool %339
OpSelectionMerge %340 None
OpBranchConditional %337 %341 %340
%341 = OpLabel
OpBranch %334
%340 = OpLabel
OpStore %innerCol_2 %105
OpBranch %343
%343 = OpLabel
OpLoopMerge %344 %345 None
OpBranch %346
%346 = OpLabel
%348 = OpLoad %uint %innerCol_2
%349 = OpULessThan %bool %348 %uint_4
%347 = OpLogicalNot %bool %349
OpSelectionMerge %350 None
OpBranchConditional %347 %351 %350
%351 = OpLabel
OpBranch %344
%350 = OpLabel
%352 = OpLoad %uint %innerRow_2
%353 = OpIMul %uint %352 %uint_4
%354 = OpLoad %uint %innerCol_2
%355 = OpIAdd %uint %353 %354
%357 = OpLoad %uint %innerRow_2
%358 = OpIAdd %uint %134 %357
%359 = OpLoad %uint %innerCol_2
%360 = OpIAdd %uint %136 %359
%361 = OpAccessChain %_ptr_Function_float %acc %355
%362 = OpLoad %float %361
%356 = OpFunctionCall %void %mm_write %358 %360 %362
OpBranch %345
%345 = OpLabel
%363 = OpLoad %uint %innerCol_2
%364 = OpIAdd %uint %363 %uint_1
OpStore %innerCol_2 %364
OpBranch %343
%344 = OpLabel
OpBranch %332
%332 = OpLabel
OpLoopMerge %333 %334 None
OpBranch %335
%335 = OpLabel
%365 = OpLoad %uint %innerRow_2
%366 = OpIAdd %uint %365 %uint_1
OpStore %innerRow_2 %366
%337 = OpLoad %uint %innerRow_2
%338 = OpULessThan %bool %337 %uint_4
%336 = OpLogicalNot %bool %338
OpSelectionMerge %339 None
OpBranchConditional %336 %340 %339
%340 = OpLabel
OpBranch %333
%339 = OpLabel
OpStore %innerCol_2 %105
OpBranch %342
%342 = OpLabel
OpLoopMerge %343 %344 None
OpBranch %345
%345 = OpLabel
%347 = OpLoad %uint %innerCol_2
%348 = OpULessThan %bool %347 %uint_4
%346 = OpLogicalNot %bool %348
OpSelectionMerge %349 None
OpBranchConditional %346 %350 %349
%350 = OpLabel
OpBranch %343
%349 = OpLabel
%351 = OpLoad %uint %innerRow_2
%352 = OpIMul %uint %351 %uint_4
%353 = OpLoad %uint %innerCol_2
%354 = OpIAdd %uint %352 %353
%356 = OpLoad %uint %innerRow_2
%357 = OpIAdd %uint %134 %356
%358 = OpLoad %uint %innerCol_2
%359 = OpIAdd %uint %136 %358
%360 = OpAccessChain %_ptr_Function_float %acc %354
%361 = OpLoad %float %360
%355 = OpFunctionCall %void %mm_write %357 %359 %361
OpBranch %344
%344 = OpLabel
%362 = OpLoad %uint %innerCol_2
%363 = OpIAdd %uint %362 %uint_1
OpStore %innerCol_2 %363
OpBranch %342
%343 = OpLabel
OpBranch %334
%334 = OpLabel
%364 = OpLoad %uint %innerRow_2
%365 = OpIAdd %uint %364 %uint_1
OpStore %innerRow_2 %365
OpBranch %332
%333 = OpLabel
OpReturn
OpFunctionEnd
%main = OpFunction %void None %367
%369 = OpLabel
%371 = OpLoad %v3uint %local_id_1
%372 = OpLoad %v3uint %global_id_1
%373 = OpLoad %uint %local_invocation_index_1
%370 = OpFunctionCall %void %main_inner %371 %372 %373
%main = OpFunction %void None %366
%368 = OpLabel
%370 = OpLoad %v3uint %local_id_1
%371 = OpLoad %v3uint %global_id_1
%372 = OpLoad %uint %local_invocation_index_1
%369 = OpFunctionCall %void %main_inner %370 %371 %372
OpReturn
OpFunctionEnd

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const float16_t r = (float16_t(1.0h) * float16_t(2.0h));
const float16_t r = float16_t(2.0h);
return;
}

View File

@ -2,7 +2,7 @@
#extension GL_AMD_gpu_shader_half_float : require
void f() {
float16_t r = (1.0hf * 2.0hf);
float16_t r = 2.0hf;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const float r = (1.0f * 2.0f);
const float r = 2.0f;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const float r = (1.0f * 2.0f);
const float r = 2.0f;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
float r = (1.0f * 2.0f);
float r = 2.0f;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 * 2);
const int r = 2;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const int r = (1 * 2);
const int r = 2;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
int r = (1 * 2);
int r = 2;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u * 2u);
const uint r = 2u;
return;
}

View File

@ -1,5 +1,5 @@
[numthreads(1, 1, 1)]
void f() {
const uint r = (1u * 2u);
const uint r = 2u;
return;
}

View File

@ -1,7 +1,7 @@
#version 310 es
void f() {
uint r = (1u * 2u);
uint r = 2u;
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

View File

@ -1,5 +1,5 @@
float main() {
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
return (2.0f / 5.0f);
}
[numthreads(2, 1, 1)]

View File

@ -1,5 +1,5 @@
float main() {
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
return (2.0f / 5.0f);
}
[numthreads(2, 1, 1)]

View File

@ -2,7 +2,7 @@
using namespace metal;
float tint_symbol() {
return (((2.0f * 3.0f) - 4.0f) / 5.0f);
return (2.0f / 5.0f);
}
kernel void ep() {

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 16
; Bound: 12
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
@ -12,19 +12,15 @@
%float = OpTypeFloat 32
%1 = OpTypeFunction %float
%float_2 = OpConstant %float 2
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_5 = OpConstant %float 5
%void = OpTypeVoid
%12 = OpTypeFunction %void
%8 = OpTypeFunction %void
%main = OpFunction %float None %1
%4 = OpLabel
%7 = OpFMul %float %float_2 %float_3
%9 = OpFSub %float %7 %float_4
%11 = OpFDiv %float %9 %float_5
OpReturnValue %11
%7 = OpFDiv %float %float_2 %float_5
OpReturnValue %7
OpFunctionEnd
%ep = OpFunction %void None %12
%15 = OpLabel
%ep = OpFunction %void None %8
%11 = OpLabel
OpReturn
OpFunctionEnd