From c20c5dfb4a3cc606f51742f24e2566df999cc1d0 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Thu, 1 Sep 2022 14:57:39 +0000 Subject: [PATCH] tint: Implement const eval of binary multiply Bug: tint:1581 Change-Id: I70ff40ed4d8faf0a665824fef936ffbafb3f0948 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99362 Kokoro: Kokoro Commit-Queue: Antonio Maiorano Reviewed-by: Ben Clayton --- src/tint/intrinsics.def | 18 +- src/tint/number.h | 42 ++ src/tint/resolver/const_eval.cc | 421 +++++++++++- src/tint/resolver/const_eval.h | 126 ++++ src/tint/resolver/const_eval_test.cc | 269 +++++++- src/tint/resolver/intrinsic_table.inl | 54 +- src/tint/resolver/intrinsic_table_test.cc | 36 +- src/tint/utils/result.h | 36 +- .../writer/glsl/generator_impl_binary_test.cc | 20 +- .../writer/hlsl/generator_impl_binary_test.cc | 8 +- test/tint/bug/tint/757.wgsl.expected.dxc.hlsl | 2 +- test/tint/bug/tint/757.wgsl.expected.fxc.hlsl | 2 +- test/tint/bug/tint/757.wgsl.expected.glsl | 2 +- test/tint/bug/tint/757.wgsl.expected.msl | 2 +- test/tint/bug/tint/757.wgsl.expected.spvasm | 8 +- test/tint/bug/tint/914.wgsl.expected.dxc.hlsl | 2 +- test/tint/bug/tint/914.wgsl.expected.fxc.hlsl | 2 +- test/tint/bug/tint/914.wgsl.expected.glsl | 2 +- test/tint/bug/tint/914.wgsl.expected.msl | 2 +- test/tint/bug/tint/914.wgsl.expected.spvasm | 627 +++++++++--------- .../scalar-scalar/f16.wgsl.expected.dxc.hlsl | 2 +- .../mul/scalar-scalar/f16.wgsl.expected.glsl | 2 +- .../scalar-scalar/f32.wgsl.expected.dxc.hlsl | 2 +- .../scalar-scalar/f32.wgsl.expected.fxc.hlsl | 2 +- .../mul/scalar-scalar/f32.wgsl.expected.glsl | 2 +- .../scalar-scalar/i32.wgsl.expected.dxc.hlsl | 2 +- .../scalar-scalar/i32.wgsl.expected.fxc.hlsl | 2 +- .../mul/scalar-scalar/i32.wgsl.expected.glsl | 2 +- .../scalar-scalar/u32.wgsl.expected.dxc.hlsl | 2 +- .../scalar-scalar/u32.wgsl.expected.fxc.hlsl | 2 +- .../mul/scalar-scalar/u32.wgsl.expected.glsl | 2 +- .../samples/function.wgsl.expected.dxc.hlsl | 2 +- .../samples/function.wgsl.expected.fxc.hlsl | 2 +- test/tint/samples/function.wgsl.expected.msl | 2 +- .../samples/function.wgsl.expected.spvasm | 16 +- 35 files changed, 1243 insertions(+), 482 deletions(-) diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index ae89f5c37d..a5651604e1 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -898,15 +898,15 @@ op ! (vec) -> vec @const op - (T, vec) -> vec @const op - (mat, mat) -> mat -op * (T, T) -> T -op * (vec, vec) -> vec -op * (vec, T) -> vec -op * (T, vec) -> vec -op * (T, mat) -> mat -op * (mat, T) -> mat -op * (mat, vec) -> vec -op * (vec, mat) -> vec -op * (mat, mat) -> mat +@const("Multiply") op * (T, T) -> T +@const("Multiply") op * (vec, vec) -> vec +@const("Multiply") op * (vec, T) -> vec +@const("Multiply") op * (T, vec) -> vec +@const("Multiply") op * (T, mat) -> mat +@const("Multiply") op * (mat, T) -> mat +@const("MultiplyMatVec") op * (mat, vec) -> vec +@const("MultiplyVecMat") op * (vec, mat) -> vec +@const("MultiplyMatMat") op * (mat, mat) -> mat op / (T, T) -> T op / (vec, vec) -> vec diff --git a/src/tint/number.h b/src/tint/number.h index 4635051177..032844c8ef 100644 --- a/src/tint/number.h +++ b/src/tint/number.h @@ -22,6 +22,7 @@ #include #include +#include "src/tint/traits.h" #include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/result.h" @@ -33,6 +34,14 @@ struct Number; } // namespace tint namespace tint::detail { +/// Base template for IsNumber +template +struct IsNumber : std::false_type {}; + +/// Specialization for IsNumber +template +struct IsNumber> : std::true_type {}; + /// An empty structure used as a unique template type for Number when /// specializing for the f16 type. struct NumberKindF16 {}; @@ -68,6 +77,10 @@ constexpr bool IsInteger = std::is_integral_v; template constexpr bool IsNumeric = IsInteger || IsFloatingPoint; +/// Evaluates to true iff T is a Number +template +constexpr bool IsNumber = detail::IsNumber::value; + /// Resolves to the underlying type for a Number. template using UnwrapNumber = typename detail::NumberUnwrapper::type; @@ -236,6 +249,26 @@ using f32 = Number; /// However since C++ don't have native binary16 type, the value is stored as float. using f16 = Number; +/// @returns the friendly name of Number type T +template >> +const char* FriendlyName() { + if constexpr (std::is_same_v) { + return "abstract-int"; + } else if constexpr (std::is_same_v) { + return "abstract-float"; + } else if constexpr (std::is_same_v) { + return "i32"; + } else if constexpr (std::is_same_v) { + return "u32"; + } else if constexpr (std::is_same_v) { + return "f32"; + } else if constexpr (std::is_same_v) { + return "f16"; + } else { + static_assert(!sizeof(T), "Unhandled type"); + } +} + /// Enumerator of failure reasons when converting from one number to another. enum class ConversionFailure { kExceedsPositiveLimit, // The value was too big (+'ve) to fit in the target type @@ -437,6 +470,15 @@ inline std::optional CheckedMul(AInt a, AInt b) { return AInt(result); } +/// @returns a * b, or an empty optional if the resulting value overflowed the AFloat +inline std::optional CheckedMul(AFloat a, AFloat b) { + auto result = a.value * b.value; + if (!std::isfinite(result)) { + return {}; + } + return AFloat{result}; +} + /// @returns a * b + c, or an empty optional if the value overflowed the AInt inline std::optional CheckedMadd(AInt a, AInt b, AInt c) { // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=80635 diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 9bc1d476df..d363b67d50 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -38,6 +38,7 @@ #include "src/tint/sem/vector.h" #include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/map.h" +#include "src/tint/utils/scoped_assignment.h" #include "src/tint/utils/transform.h" using namespace tint::number_suffixes; // NOLINT @@ -508,11 +509,202 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder, auto* ty = n0 > n1 ? c0->Type() : c1->Type(); return CreateComposite(builder, ty, std::move(els)); } - } // namespace ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {} +template +utils::Result ConstEval::Add(NumberT a, NumberT b) { + using T = UnwrapNumber; + auto add_values = [](T lhs, T rhs) { + if constexpr (std::is_integral_v && std::is_signed_v) { + // Ensure no UB for signed overflow + using UT = std::make_unsigned_t; + return static_cast(static_cast(lhs) + static_cast(rhs)); + } else { + return lhs + rhs; + } + }; + NumberT result; + if constexpr (std::is_same_v || std::is_same_v) { + // Check for over/underflow for abstract values + if (auto r = CheckedAdd(a, b)) { + result = r->value; + } else { + AddError("'" + std::to_string(add_values(a.value, b.value)) + + "' cannot be represented as '" + FriendlyName() + "'", + *current_source); + return utils::Failure; + } + } else { + result = add_values(a.value, b.value); + } + return result; +} + +template +utils::Result ConstEval::Mul(NumberT a, NumberT b) { + using T = UnwrapNumber; + auto mul_values = [](T lhs, T rhs) { // + if constexpr (std::is_integral_v && std::is_signed_v) { + // For signed integrals, avoid C++ UB by multiplying as unsigned + using UT = std::make_unsigned_t; + return static_cast(static_cast(lhs) * static_cast(rhs)); + } else { + return lhs * rhs; + } + }; + NumberT result; + if constexpr (std::is_same_v || std::is_same_v) { + // Check for over/underflow for abstract values + if (auto r = CheckedMul(a, b)) { + result = r->value; + } else { + AddError("'" + std::to_string(mul_values(a.value, b.value)) + + "' cannot be represented as '" + FriendlyName() + "'", + *current_source); + return utils::Failure; + } + } else { + result = mul_values(a.value, b.value); + } + return result; +} + +template +utils::Result ConstEval::Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2) { + auto r1 = Mul(a1, b1); + if (!r1) { + return utils::Failure; + } + auto r2 = Mul(a2, b2); + if (!r2) { + return utils::Failure; + } + auto r = Add(r1.Get(), r2.Get()); + if (!r) { + return utils::Failure; + } + return r; +} + +template +utils::Result ConstEval::Dot3(NumberT a1, + NumberT a2, + NumberT a3, + NumberT b1, + NumberT b2, + NumberT b3) { + auto r1 = Mul(a1, b1); + if (!r1) { + return utils::Failure; + } + auto r2 = Mul(a2, b2); + if (!r2) { + return utils::Failure; + } + auto r3 = Mul(a3, b3); + if (!r3) { + return utils::Failure; + } + auto r = Add(r1.Get(), r2.Get()); + if (!r) { + return utils::Failure; + } + r = Add(r.Get(), r3.Get()); + if (!r) { + return utils::Failure; + } + return r; +} + +template +utils::Result ConstEval::Dot4(NumberT a1, + NumberT a2, + NumberT a3, + NumberT a4, + NumberT b1, + NumberT b2, + NumberT b3, + NumberT b4) { + auto r1 = Mul(a1, b1); + if (!r1) { + return utils::Failure; + } + auto r2 = Mul(a2, b2); + if (!r2) { + return utils::Failure; + } + auto r3 = Mul(a3, b3); + if (!r3) { + return utils::Failure; + } + auto r4 = Mul(a4, b4); + if (!r4) { + return utils::Failure; + } + auto r = Add(r1.Get(), r2.Get()); + if (!r) { + return utils::Failure; + } + r = Add(r.Get(), r3.Get()); + if (!r) { + return utils::Failure; + } + r = Add(r.Get(), r4.Get()); + if (!r) { + return utils::Failure; + } + return r; +} + +auto ConstEval::AddFunc(const sem::Type* elem_ty) { + return [=](auto a1, auto a2) -> utils::Result { + if (auto r = Add(a1, a2)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + +auto ConstEval::MulFunc(const sem::Type* elem_ty) { + return [=](auto a1, auto a2) -> utils::Result { + if (auto r = Mul(a1, a2)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + +auto ConstEval::Dot2Func(const sem::Type* elem_ty) { + return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result { + if (auto r = Dot2(a1, a2, b1, b2)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + +auto ConstEval::Dot3Func(const sem::Type* elem_ty) { + return [=](auto a1, auto a2, auto a3, auto b1, auto b2, + auto b3) -> utils::Result { + if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + +auto ConstEval::Dot4Func(const sem::Type* elem_ty) { + return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, + auto b4) -> utils::Result { + if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) { + return CreateElement(builder, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) { return Switch( @@ -756,42 +948,15 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type*, return TransformElements(builder, transform, args[0]); } -ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, +ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type*, utils::VectorRef args, const Source& source) { - auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { - auto create = [&](auto i, auto j) -> const Constant* { - using NumberT = decltype(i); - using T = UnwrapNumber; - - auto add_values = [](T lhs, T rhs) { - if constexpr (std::is_integral_v && std::is_signed_v) { - // Ensure no UB for signed overflow - using UT = std::make_unsigned_t; - return static_cast(static_cast(lhs) + static_cast(rhs)); - } else { - return lhs + rhs; - } - }; - - NumberT result; - if constexpr (std::is_same_v || std::is_same_v) { - // Check for over/underflow for abstract values - if (auto r = CheckedAdd(i, j)) { - result = r->value; - } else { - AddError("'" + std::to_string(add_values(i.value, j.value)) + - "' cannot be represented as '" + - ty->FriendlyName(builder.Symbols()) + "'", - source); - return nullptr; - } - } else { - result = add_values(i.value, j.value); - } - return CreateElement(builder, c0->Type(), result); - }; - return Dispatch_fia_fiu32_f16(create, c0, c1); + TINT_SCOPED_ASSIGNMENT(current_source, &source); + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { + if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) { + return r.Get(); + } + return nullptr; }; auto r = TransformBinaryElements(builder, transform, args[0], args[1]); @@ -846,6 +1011,192 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty, return r; } +ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* /*ty*/, + utils::VectorRef args, + const Source& source) { + TINT_SCOPED_ASSIGNMENT(current_source, &source); + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { + if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) { + return r.Get(); + } + return nullptr; + }; + + auto r = TransformBinaryElements(builder, transform, args[0], args[1]); + if (builder.Diagnostics().contains_errors()) { + return utils::Failure; + } + return r; +} + +ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { + TINT_SCOPED_ASSIGNMENT(current_source, &source); + auto* mat_ty = args[0]->Type()->As(); + auto* vec_ty = args[1]->Type()->As(); + auto* elem_ty = vec_ty->type(); + + auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) { + utils::Result result; + switch (mat_ty->columns()) { + case 2: + result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // + m->Index(0)->Index(row), // + m->Index(1)->Index(row), // + v->Index(0), // + v->Index(1)); + break; + case 3: + result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // + m->Index(0)->Index(row), // + m->Index(1)->Index(row), // + m->Index(2)->Index(row), // + v->Index(0), // + v->Index(1), v->Index(2)); + break; + case 4: + result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // + m->Index(0)->Index(row), // + m->Index(1)->Index(row), // + m->Index(2)->Index(row), // + m->Index(3)->Index(row), // + v->Index(0), // + v->Index(1), // + v->Index(2), // + v->Index(3)); + break; + } + return result; + }; + + utils::Vector result; + for (size_t i = 0; i < mat_ty->rows(); ++i) { + auto r = dot(args[0], i, args[1]); // matrix row i * vector + if (!r) { + return utils::Failure; + } + result.Push(r.Get()); + } + return CreateComposite(builder, ty, result); +} +ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { + TINT_SCOPED_ASSIGNMENT(current_source, &source); + auto* vec_ty = args[0]->Type()->As(); + auto* mat_ty = args[1]->Type()->As(); + auto* elem_ty = vec_ty->type(); + + auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) { + utils::Result result; + switch (mat_ty->rows()) { + case 2: + result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // + m->Index(col)->Index(0), // + m->Index(col)->Index(1), // + v->Index(0), // + v->Index(1)); + break; + case 3: + result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // + m->Index(col)->Index(0), // + m->Index(col)->Index(1), // + m->Index(col)->Index(2), + v->Index(0), // + v->Index(1), // + v->Index(2)); + break; + case 4: + result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // + m->Index(col)->Index(0), // + m->Index(col)->Index(1), // + m->Index(col)->Index(2), // + m->Index(col)->Index(3), // + v->Index(0), // + v->Index(1), // + v->Index(2), // + v->Index(3)); + } + return result; + }; + + utils::Vector result; + for (size_t i = 0; i < mat_ty->columns(); ++i) { + auto r = dot(args[0], args[1], i); // vector * matrix col i + if (!r) { + return utils::Failure; + } + result.Push(r.Get()); + } + return CreateComposite(builder, ty, result); +} + +ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { + TINT_SCOPED_ASSIGNMENT(current_source, &source); + auto* mat1 = args[0]; + auto* mat2 = args[1]; + auto* mat1_ty = mat1->Type()->As(); + auto* mat2_ty = mat2->Type()->As(); + auto* elem_ty = mat1_ty->type(); + + auto dot = [&](const sem::Constant* m1, size_t row, const sem::Constant* m2, size_t col) { + auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; + auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; + + utils::Result result; + switch (mat1_ty->columns()) { + case 2: + result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // + m1e(row, 0), // + m1e(row, 1), // + m2e(0, col), // + m2e(1, col)); + break; + case 3: + result = Dispatch_fa_f32_f16(Dot3Func(elem_ty), // + m1e(row, 0), // + m1e(row, 1), // + m1e(row, 2), // + m2e(0, col), // + m2e(1, col), // + m2e(2, col)); + break; + case 4: + result = Dispatch_fa_f32_f16(Dot4Func(elem_ty), // + m1e(row, 0), // + m1e(row, 1), // + m1e(row, 2), // + m1e(row, 3), // + m2e(0, col), // + m2e(1, col), // + m2e(2, col), // + m2e(3, col)); + break; + } + return result; + }; + + utils::Vector result_mat; + for (size_t c = 0; c < mat2_ty->columns(); ++c) { + utils::Vector col_vec; + for (size_t r = 0; r < mat1_ty->rows(); ++r) { + auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c + if (!v) { + return utils::Failure; + } + col_vec.Push(v.Get()); // mat1 row r * mat2 col c + } + + // Add column vector to matrix + auto* col_vec_ty = ty->As()->ColumnType(); + result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec)); + } + return CreateComposite(builder, ty, result_mat); +} + ConstEval::ConstantResult ConstEval::atan2(const sem::Type*, utils::VectorRef args, const Source&) { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 9cd38d7054..400716292e 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -230,6 +230,42 @@ class ConstEval { utils::VectorRef args, const Source& source); + /// Multiply operator '*' for the same type on the LHS and RHS + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location of the conversion + /// @return the result value, or null if the value cannot be calculated + ConstantResult OpMultiply(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + + /// Multiply operator '*' for matCxR * vecC + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location of the conversion + /// @return the result value, or null if the value cannot be calculated + ConstantResult OpMultiplyMatVec(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + + /// Multiply operator '*' for vecR * matCxR + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location of the conversion + /// @return the result value, or null if the value cannot be calculated + ConstantResult OpMultiplyVecMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + + /// Multiply operator '*' for matKxR * matCxK + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location of the conversion + /// @return the result value, or null if the value cannot be calculated + ConstantResult OpMultiplyMatMat(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + //////////////////////////////////////////////////////////////////////////// // Builtins //////////////////////////////////////////////////////////////////////////// @@ -259,7 +295,97 @@ class ConstEval { /// Adds the given warning message to the diagnostics void AddWarning(const std::string& msg, const Source& source) const; + /// Adds two Numbers + /// @param a the lhs number + /// @param b the rhs number + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Add(NumberT a, NumberT b); + + /// Multiplies two Numbers + /// @param a the lhs number + /// @param b the rhs number + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Mul(NumberT a, NumberT b); + + /// Returns the dot product of (a1,a2) with (b1,b2) + /// @param a1 component 1 of lhs vector + /// @param a2 component 2 of lhs vector + /// @param b1 component 1 of rhs vector + /// @param b2 component 2 of rhs vector + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Dot2(NumberT a1, NumberT a2, NumberT b1, NumberT b2); + + /// Returns the dot product of (a1,a2,a3) with (b1,b2,b3) + /// @param a1 component 1 of lhs vector + /// @param a2 component 2 of lhs vector + /// @param a3 component 3 of lhs vector + /// @param b1 component 1 of rhs vector + /// @param b2 component 2 of rhs vector + /// @param b3 component 3 of rhs vector + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Dot3(NumberT a1, + NumberT a2, + NumberT a3, + NumberT b1, + NumberT b2, + NumberT b3); + + /// Returns the dot product of (a1,b1,c1,d1) with (a2,b2,c2,d2) + /// @param a1 component 1 of lhs vector + /// @param a2 component 2 of lhs vector + /// @param a3 component 3 of lhs vector + /// @param a4 component 4 of lhs vector + /// @param b1 component 1 of rhs vector + /// @param b2 component 2 of rhs vector + /// @param b3 component 3 of rhs vector + /// @param b4 component 4 of rhs vector + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Dot4(NumberT a1, + NumberT a2, + NumberT a3, + NumberT a4, + NumberT b1, + NumberT b2, + NumberT b3, + NumberT b4); + + /// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto AddFunc(const sem::Type* elem_ty); + + /// Returns a callable that calls Mul, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto MulFunc(const sem::Type* elem_ty); + + /// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto Dot2Func(const sem::Type* elem_ty); + + /// Returns a callable that calls Dot3, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto Dot3Func(const sem::Type* elem_ty); + + /// Returns a callable that calls Dot4, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto Dot4Func(const sem::Type* elem_ty); + ProgramBuilder& builder; + const Source* current_source = nullptr; }; } // namespace tint::resolver diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index 23e621f01c..a83cfa1bff 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -15,6 +15,7 @@ #include #include +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "src/tint/resolver/resolver_test_helper.h" #include "src/tint/sem/builtin_type.h" @@ -24,6 +25,8 @@ #include "src/tint/sem/test_helper.h" #include "src/tint/utils/transform.h" +using ::testing::HasSubstr; + using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { @@ -74,6 +77,19 @@ auto Abs(const Number& v) { } } +TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW); +template +constexpr Number Mul(Number v1, Number v2) { + if constexpr (std::is_integral_v && std::is_signed_v) { + // For signed integrals, avoid C++ UB by multiplying as unsigned + using UT = std::make_unsigned_t; + return static_cast>(static_cast(v1) * static_cast(v2)); + } else { + return static_cast>(v1 * v2); + } +} +TINT_END_DISABLE_WARNING(CONSTANT_OVERFLOW); + // Concats any number of std::vectors template [[nodiscard]] auto Concat(Vec&& v1, Vecs&&... vs) { @@ -3228,29 +3244,26 @@ Case C(T lhs, U rhs, V expected, bool overflow = false) { return Case{Val(lhs), Val(rhs), Val(expected), overflow}; } -static std::ostream& operator<<(std::ostream& o, const Case& c) { - auto print_value = [&](auto&& value) { - std::visit( - [&](auto&& v) { - using ValueType = std::decay_t; - o << ValueType::DataType::Name() << "("; - for (auto& a : v.args.values) { - o << std::get(a); - if (&a != &v.args.values.Back()) { - o << ", "; - } +static std::ostream& operator<<(std::ostream& o, const Types& types) { + std::visit( + [&](auto&& v) { + using ValueType = std::decay_t; + o << ValueType::DataType::Name() << "("; + for (auto& a : v.args.values) { + o << std::get(a); + if (&a != &v.args.values.Back()) { + o << ", "; } - o << ")"; - }, - value); - }; - o << "lhs: "; - print_value(c.lhs); - o << ", rhs: "; - print_value(c.rhs); - o << ", expected: "; - print_value(c.expected); - o << ", overflow: " << c.overflow; + } + o << ")"; + }, + types); + return o; +} + +static std::ostream& operator<<(std::ostream& o, const Case& c) { + o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected + << ", overflow: " << c.overflow; return o; } @@ -3281,7 +3294,6 @@ bool ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f) { using ResolverConstEvalBinaryOpTest = ResolverTestWithParam>; TEST_P(ResolverConstEvalBinaryOpTest, Test) { Enable(ast::Extension::kF16); - auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); @@ -3300,10 +3312,8 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { auto* expr = create(op, lhs_expr, rhs_expr); GlobalConst("C", expr); - auto* expected_expr = expected.Expr(*this); GlobalConst("E", expected_expr); - EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); @@ -3413,6 +3423,215 @@ INSTANTIATE_TEST_SUITE_P(Sub, OpSubFloatCases(), OpSubFloatCases())))); +template +std::vector OpMulScalarCases() { + return { + C(T{0}, T{0}, T{0}), + C(T{1}, T{2}, T{2}), + C(T{2}, T{3}, T{6}), + C(Negate(T{2}), T{3}, Negate(T{6})), + C(T::Highest(), T{1}, T::Highest()), + C(T::Lowest(), T{1}, T::Lowest()), + C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true), + C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true), + }; +} + +template +std::vector OpMulVecCases() { + return { + // s * vec3 = vec3 + C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})), + // vec3 * s = vec3 + C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), + // vec3 * vec3 = vec3 + C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), + }; +} + +template +std::vector OpMulMatCases() { + return { + // s * mat3x2 = mat3x2 + C(Val(T{2.25}), + Mat({T{1.0}, T{4.0}}, // + {T{2.0}, T{5.0}}, // + {T{3.0}, T{6.0}}), + Mat({T{2.25}, T{9.0}}, // + {T{4.5}, T{11.25}}, // + {T{6.75}, T{13.5}})), + // mat3x2 * s = mat3x2 + C(Mat({T{1.0}, T{4.0}}, // + {T{2.0}, T{5.0}}, // + {T{3.0}, T{6.0}}), + Val(T{2.25}), + Mat({T{2.25}, T{9.0}}, // + {T{4.5}, T{11.25}}, // + {T{6.75}, T{13.5}})), + // vec3 * mat2x3 = vec2 + C(Vec(T{1.25}, T{2.25}, T{3.25}), // + Mat({T{1.0}, T{2.0}, T{3.0}}, // + {T{4.0}, T{5.0}, T{6.0}}), // + Vec(T{15.5}, T{35.75})), + // mat2x3 * vec2 = vec3 + C(Mat({T{1.0}, T{2.0}, T{3.0}}, // + {T{4.0}, T{5.0}, T{6.0}}), // + Vec(T{1.25}, T{2.25}), // + Vec(T{10.25}, T{13.75}, T{17.25})), + // mat3x2 * mat2x3 = mat2x2 + C(Mat({T{1.0}, T{2.0}}, // + {T{3.0}, T{4.0}}, // + {T{5.0}, T{6.0}}), // + Mat({T{1.25}, T{2.25}, T{3.25}}, // + {T{4.25}, T{5.25}, T{6.25}}), // + Mat({T{24.25}, T{31.0}}, // + {T{51.25}, T{67.0}})), // + }; +} + +INSTANTIATE_TEST_SUITE_P(Mul, + ResolverConstEvalBinaryOpTest, + testing::Combine( // + testing::Values(ast::BinaryOp::kMultiply), + testing::ValuesIn(Concat( // + OpMulScalarCases(), + OpMulScalarCases(), + OpMulScalarCases(), + OpMulScalarCases(), + OpMulScalarCases(), + OpMulScalarCases(), + OpMulVecCases(), + OpMulVecCases(), + OpMulVecCases(), + OpMulVecCases(), + OpMulVecCases(), + OpMulVecCases(), + OpMulMatCases(), + OpMulMatCases(), + OpMulMatCases())))); + +// Tests for errors on overflow/underflow of binary operations with abstract numbers +struct OverflowCase { + ast::BinaryOp op; + Types lhs; + Types rhs; + std::string overflowed_result; +}; + +static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) { + o << ast::FriendlyName(c.op) << ", lhs: " << c.lhs << ", rhs: " << c.rhs; + return o; +} +using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam; +TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) { + Enable(ast::Extension::kF16); + auto& c = GetParam(); + auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs); + auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs); + auto* expr = create(Source{{1, 1}}, c.op, lhs_expr, rhs_expr); + GlobalConst("C", expr); + ASSERT_FALSE(r()->Resolve()); + + std::string type_name = std::visit( + [&](auto&& value) { + using ValueType = std::decay_t; + return tint::FriendlyName(); + }, + c.lhs); + + EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '" + c.overflowed_result + + "' cannot be represented as '" + type_name + "'")); +} +INSTANTIATE_TEST_SUITE_P( + Test, + ResolverConstEvalBinaryOpTest_Overflow, + testing::Values( // + // scalar-scalar add + OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a), "-9223372036854775808"}, + OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a), "9223372036854775807"}, + OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest()), "inf"}, + OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest()), "-inf"}, + // scalar-scalar subtract + OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a), + "9223372036854775807"}, + OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a), + "-9223372036854775808"}, + OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest()), + "inf"}, + OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest()), + "-inf"}, + + // scalar-scalar multiply + OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a), "-2"}, + OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a), "0"}, + + // scalar-vector multiply + OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a), "-2"}, + OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a), "0"}, + + // vector-matrix multiply + + // Overflow from first multiplication of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Vec(AFloat::Highest(), 1.0_a), // + Mat({2.0_a, 1.0_a}, // + {1.0_a, 1.0_a}), // + "inf"}, + + // Overflow from second multiplication of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Vec(1.0_a, AFloat::Highest()), // + Mat({1.0_a, 2.0_a}, // + {1.0_a, 1.0_a}), // + "inf"}, + + // Overflow from addition of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Vec(AFloat::Highest(), AFloat::Highest()), // + Mat({1.0_a, 1.0_a}, // + {1.0_a, 1.0_a}), // + "inf"}, + + // matrix-matrix multiply + + // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Mat({AFloat::Highest(), 1.0_a}, // + {1.0_a, 1.0_a}), // + Mat({2.0_a, 1.0_a}, // + {1.0_a, 1.0_a}), // + "inf"}, + + // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Mat({1.0_a, AFloat::Highest()}, // + {1.0_a, 1.0_a}), // + Mat({1.0_a, 1.0_a}, // + {2.0_a, 1.0_a}), // + "inf"}, + + // Overflow from addition of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + OverflowCase{ast::BinaryOp::kMultiply, // + Mat({AFloat::Highest(), 1.0_a}, // + {AFloat::Highest(), 1.0_a}), // + Mat({1.0_a, 1.0_a}, // + {1.0_a, 1.0_a}), // + "inf"} + + )); + TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); EXPECT_FALSE(r()->Resolve()); diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl index ff5878a2d7..e0a8ddcab3 100644 --- a/src/tint/resolver/intrinsic_table.inl +++ b/src/tint/resolver/intrinsic_table.inl @@ -9572,108 +9572,108 @@ constexpr OverloadInfo kOverloads[] = { /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 0, - /* template types */ &kTemplateTypes[15], + /* template types */ &kTemplateTypes[13], /* template numbers */ &kTemplateNumbers[10], /* parameters */ &kParameters[727], /* return matcher indices */ &kMatcherIndices[1], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [118] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[15], + /* template types */ &kTemplateTypes[13], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[725], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [119] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[15], + /* template types */ &kTemplateTypes[13], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[723], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [120] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[15], + /* template types */ &kTemplateTypes[13], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[721], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [121] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 2, - /* template types */ &kTemplateTypes[11], + /* template types */ &kTemplateTypes[12], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[719], /* return matcher indices */ &kMatcherIndices[10], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [122] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 2, - /* template types */ &kTemplateTypes[11], + /* template types */ &kTemplateTypes[12], /* template numbers */ &kTemplateNumbers[6], /* parameters */ &kParameters[717], /* return matcher indices */ &kMatcherIndices[10], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiply, }, { /* [123] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 2, - /* template types */ &kTemplateTypes[11], + /* template types */ &kTemplateTypes[12], /* template numbers */ &kTemplateNumbers[1], /* parameters */ &kParameters[715], /* return matcher indices */ &kMatcherIndices[69], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiplyMatVec, }, { /* [124] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 2, - /* template types */ &kTemplateTypes[11], + /* template types */ &kTemplateTypes[12], /* template numbers */ &kTemplateNumbers[1], /* parameters */ &kParameters[713], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiplyVecMat, }, { /* [125] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 3, - /* template types */ &kTemplateTypes[11], + /* template types */ &kTemplateTypes[12], /* template numbers */ &kTemplateNumbers[0], /* parameters */ &kParameters[711], /* return matcher indices */ &kMatcherIndices[22], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpMultiplyMatMat, }, { /* [126] */ @@ -14630,15 +14630,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = { }, { /* [2] */ - /* op *(T, T) -> T */ - /* op *(vec, vec) -> vec */ - /* op *(vec, T) -> vec */ - /* op *(T, vec) -> vec */ - /* op *(T, mat) -> mat */ - /* op *(mat, T) -> mat */ - /* op *(mat, vec) -> vec */ - /* op *(vec, mat) -> vec */ - /* op *(mat, mat) -> mat */ + /* op *(T, T) -> T */ + /* op *(vec, vec) -> vec */ + /* op *(vec, T) -> vec */ + /* op *(T, vec) -> vec */ + /* op *(T, mat) -> mat */ + /* op *(mat, T) -> mat */ + /* op *(mat, vec) -> vec */ + /* op *(vec, mat) -> vec */ + /* op *(mat, mat) -> mat */ /* num overloads */ 9, /* overloads */ &kOverloads[117], }, diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc index b9f9d53c5f..0ffc011464 100644 --- a/src/tint/resolver/intrinsic_table_test.cc +++ b/src/tint/resolver/intrinsic_table_test.cc @@ -641,15 +641,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) { EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool) 9 candidate operators: - operator * (T, T) -> T where: T is f32, i32, u32 or f16 - operator * (vecN, T) -> vecN where: T is f32, i32, u32 or f16 - operator * (T, vecN) -> vecN where: T is f32, i32, u32 or f16 - operator * (T, matNxM) -> matNxM where: T is f32 or f16 - operator * (matNxM, T) -> matNxM where: T is f32 or f16 - operator * (vecN, vecN) -> vecN where: T is f32, i32, u32 or f16 - operator * (matCxR, vecC) -> vecR where: T is f32 or f16 - operator * (vecR, matCxR) -> vecC where: T is f32 or f16 - operator * (matKxR, matCxK) -> matCxR where: T is f32 or f16 + operator * (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator * (vecN, T) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator * (T, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator * (T, matNxM) -> matNxM where: T is abstract-float, f32 or f16 + operator * (matNxM, T) -> matNxM where: T is abstract-float, f32 or f16 + operator * (vecN, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator * (matCxR, vecC) -> vecR where: T is abstract-float, f32 or f16 + operator * (vecR, matCxR) -> vecC where: T is abstract-float, f32 or f16 + operator * (matKxR, matCxK) -> matCxR where: T is abstract-float, f32 or f16 )"); } @@ -673,15 +673,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) { EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool) 9 candidate operators: - operator *= (T, T) -> T where: T is f32, i32, u32 or f16 - operator *= (vecN, T) -> vecN where: T is f32, i32, u32 or f16 - operator *= (T, vecN) -> vecN where: T is f32, i32, u32 or f16 - operator *= (T, matNxM) -> matNxM where: T is f32 or f16 - operator *= (matNxM, T) -> matNxM where: T is f32 or f16 - operator *= (vecN, vecN) -> vecN where: T is f32, i32, u32 or f16 - operator *= (matCxR, vecC) -> vecR where: T is f32 or f16 - operator *= (vecR, matCxR) -> vecC where: T is f32 or f16 - operator *= (matKxR, matCxK) -> matCxR where: T is f32 or f16 + operator *= (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator *= (vecN, T) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator *= (T, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator *= (T, matNxM) -> matNxM where: T is abstract-float, f32 or f16 + operator *= (matNxM, T) -> matNxM where: T is abstract-float, f32 or f16 + operator *= (vecN, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 + operator *= (matCxR, vecC) -> vecR where: T is abstract-float, f32 or f16 + operator *= (vecR, matCxR) -> vecC where: T is abstract-float, f32 or f16 + operator *= (matKxR, matCxK) -> matCxR where: T is abstract-float, f32 or f16 )"); } diff --git a/src/tint/utils/result.h b/src/tint/utils/result.h index b535f4fcdb..6a14352670 100644 --- a/src/tint/utils/result.h +++ b/src/tint/utils/result.h @@ -17,6 +17,7 @@ #include #include +#include "src/tint/debug.h" namespace tint::utils { @@ -36,6 +37,9 @@ struct [[nodiscard]] Result { static_assert(!std::is_same_v, "Result must not have the same type for SUCCESS_TYPE and FAILURE_TYPE"); + /// Default constructor initializes to invalid state + Result() : value(std::monostate{}) {} + /// Constructor /// @param success the success result Result(const SUCCESS_TYPE& success) // NOLINT(runtime/explicit): @@ -47,27 +51,43 @@ struct [[nodiscard]] Result { : value{failure} {} /// @returns true if the result was a success - operator bool() const { return std::holds_alternative(value); } + operator bool() const { + Validate(); + return std::holds_alternative(value); + } /// @returns true if the result was a failure - bool operator!() const { return std::holds_alternative(value); } + bool operator!() const { + Validate(); + return std::holds_alternative(value); + } /// @returns the success value /// @warning attempting to call this when the Result holds an failure will result in UB. - const SUCCESS_TYPE* operator->() const { return &std::get(value); } + const SUCCESS_TYPE* operator->() const { + Validate(); + return &(Get()); + } /// @returns the success value /// @warning attempting to call this when the Result holds an failure value will result in UB. - const SUCCESS_TYPE& Get() const { return std::get(value); } + const SUCCESS_TYPE& Get() const { + Validate(); + return std::get(value); + } /// @returns the failure value /// @warning attempting to call this when the Result holds a success value will result in UB. - const FAILURE_TYPE& Failure() const { return std::get(value); } + const FAILURE_TYPE& Failure() const { + Validate(); + return std::get(value); + } /// Equality operator /// @param val the value to compare this Result to /// @returns true if this result holds a success value equal to `value` bool operator==(SUCCESS_TYPE val) const { + Validate(); if (auto* v = std::get_if(&value)) { return *v == val; } @@ -78,14 +98,18 @@ struct [[nodiscard]] Result { /// @param val the value to compare this Result to /// @returns true if this result holds a failure value equal to `value` bool operator==(FAILURE_TYPE val) const { + Validate(); if (auto* v = std::get_if(&value)) { return *v == val; } return false; } + private: + void Validate() const { TINT_ASSERT(Utils, !std::holds_alternative(value)); } + /// The result. Either a success of failure value. - std::variant value; + std::variant value; }; /// Writes the result to the ostream. diff --git a/src/tint/writer/glsl/generator_impl_binary_test.cc b/src/tint/writer/glsl/generator_impl_binary_test.cc index 9113ae6f8f..23e2ece759 100644 --- a/src/tint/writer/glsl/generator_impl_binary_test.cc +++ b/src/tint/writer/glsl/generator_impl_binary_test.cc @@ -151,7 +151,8 @@ INSTANTIATE_TEST_SUITE_P( BinaryData{"(left % right)", ast::BinaryOp::kModulo})); TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) { - auto* lhs = vec3(1_f, 1_f, 1_f); + GlobalVar("a", vec3(1_f, 1_f, 1_f), ast::StorageClass::kPrivate); + auto* lhs = Expr("a"); auto* rhs = Expr(1_f); auto* expr = create(ast::BinaryOp::kMultiply, lhs, rhs); @@ -162,13 +163,14 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)"); + EXPECT_EQ(out.str(), "(a * 1.0f)"); } TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) { Enable(ast::Extension::kF16); - auto* lhs = vec3(1_h, 1_h, 1_h); + GlobalVar("a", vec3(1_h, 1_h, 1_h), ast::StorageClass::kPrivate); + auto* lhs = Expr("a"); auto* rhs = Expr(1_h); auto* expr = create(ast::BinaryOp::kMultiply, lhs, rhs); @@ -179,12 +181,13 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)"); + EXPECT_EQ(out.str(), "(a * 1.0hf)"); } TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) { + GlobalVar("a", vec3(1_f, 1_f, 1_f), ast::StorageClass::kPrivate); auto* lhs = Expr(1_f); - auto* rhs = vec3(1_f, 1_f, 1_f); + auto* rhs = Expr("a"); auto* expr = create(ast::BinaryOp::kMultiply, lhs, rhs); @@ -194,14 +197,15 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))"); + EXPECT_EQ(out.str(), "(1.0f * a)"); } TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) { Enable(ast::Extension::kF16); + GlobalVar("a", vec3(1_h, 1_h, 1_h), ast::StorageClass::kPrivate); auto* lhs = Expr(1_h); - auto* rhs = vec3(1_h, 1_h, 1_h); + auto* rhs = Expr("a"); auto* expr = create(ast::BinaryOp::kMultiply, lhs, rhs); @@ -211,7 +215,7 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))"); + EXPECT_EQ(out.str(), "(1.0hf * a)"); } TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) { diff --git a/src/tint/writer/hlsl/generator_impl_binary_test.cc b/src/tint/writer/hlsl/generator_impl_binary_test.cc index aa541441e8..8a87697055 100644 --- a/src/tint/writer/hlsl/generator_impl_binary_test.cc +++ b/src/tint/writer/hlsl/generator_impl_binary_test.cc @@ -184,7 +184,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)"); + EXPECT_EQ(out.str(), "(1.0f).xxx"); } TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) { @@ -201,7 +201,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))"); + EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx"); } TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) { @@ -216,7 +216,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)"); + EXPECT_EQ(out.str(), "(1.0f).xxx"); } TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) { @@ -233,7 +233,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) { std::stringstream out; EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)"); + EXPECT_EQ(out.str(), "(float16_t(1.0h)).xxx"); } TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) { diff --git a/test/tint/bug/tint/757.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/757.wgsl.expected.dxc.hlsl index d144aff8e9..a0788b89f2 100644 --- a/test/tint/bug/tint/757.wgsl.expected.dxc.hlsl +++ b/test/tint/bug/tint/757.wgsl.expected.dxc.hlsl @@ -10,7 +10,7 @@ struct tint_symbol_1 { }; void main_inner(uint3 GlobalInvocationID) { - uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); + uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); flatIndex = (flatIndex * 1u); float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0)); { diff --git a/test/tint/bug/tint/757.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/757.wgsl.expected.fxc.hlsl index d144aff8e9..a0788b89f2 100644 --- a/test/tint/bug/tint/757.wgsl.expected.fxc.hlsl +++ b/test/tint/bug/tint/757.wgsl.expected.fxc.hlsl @@ -10,7 +10,7 @@ struct tint_symbol_1 { }; void main_inner(uint3 GlobalInvocationID) { - uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); + uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); flatIndex = (flatIndex * 1u); float4 texel = myTexture.Load(int4(int3(int2(GlobalInvocationID.xy), 0), 0)); { diff --git a/test/tint/bug/tint/757.wgsl.expected.glsl b/test/tint/bug/tint/757.wgsl.expected.glsl index ef8ff63fad..cd7ba225f5 100644 --- a/test/tint/bug/tint/757.wgsl.expected.glsl +++ b/test/tint/bug/tint/757.wgsl.expected.glsl @@ -9,7 +9,7 @@ layout(binding = 3, std430) buffer Result_1 { } result; uniform highp sampler2DArray myTexture_1; void tint_symbol(uvec3 GlobalInvocationID) { - uint flatIndex = ((((2u * 2u) * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); + uint flatIndex = (((4u * GlobalInvocationID.z) + (2u * GlobalInvocationID.y)) + GlobalInvocationID.x); flatIndex = (flatIndex * 1u); vec4 texel = texelFetch(myTexture_1, ivec3(ivec2(GlobalInvocationID.xy), 0), 0); { diff --git a/test/tint/bug/tint/757.wgsl.expected.msl b/test/tint/bug/tint/757.wgsl.expected.msl index b0644d05d0..0f16050db4 100644 --- a/test/tint/bug/tint/757.wgsl.expected.msl +++ b/test/tint/bug/tint/757.wgsl.expected.msl @@ -23,7 +23,7 @@ struct Result { }; void tint_symbol_inner(uint3 GlobalInvocationID, texture2d_array tint_symbol_1, device Result* const tint_symbol_2) { - uint flatIndex = ((((2u * 2u) * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]); + uint flatIndex = (((4u * GlobalInvocationID[2]) + (2u * GlobalInvocationID[1])) + GlobalInvocationID[0]); flatIndex = (flatIndex * 1u); float4 texel = tint_symbol_1.read(uint2(int2(uint3(GlobalInvocationID).xy)), 0, 0); for(uint i = 0u; (i < 1u); i = (i + 1u)) { diff --git a/test/tint/bug/tint/757.wgsl.expected.spvasm b/test/tint/bug/tint/757.wgsl.expected.spvasm index 4ad4cef2b4..c660ecbe2d 100644 --- a/test/tint/bug/tint/757.wgsl.expected.spvasm +++ b/test/tint/bug/tint/757.wgsl.expected.spvasm @@ -52,6 +52,7 @@ %result = OpVariable %_ptr_StorageBuffer_Result StorageBuffer %void = OpTypeVoid %17 = OpTypeFunction %void %v3uint + %uint_4 = OpConstant %uint 4 %uint_2 = OpConstant %uint 2 %_ptr_Function_uint = OpTypePointer Function %uint %33 = OpConstantNull %uint @@ -74,12 +75,11 @@ %flatIndex = OpVariable %_ptr_Function_uint Function %33 %texel = OpVariable %_ptr_Function_v4float Function %51 %i = OpVariable %_ptr_Function_uint Function %33 - %23 = OpIMul %uint %uint_2 %uint_2 - %24 = OpCompositeExtract %uint %GlobalInvocationID 2 - %25 = OpIMul %uint %23 %24 + %23 = OpCompositeExtract %uint %GlobalInvocationID 2 + %24 = OpIMul %uint %uint_4 %23 %26 = OpCompositeExtract %uint %GlobalInvocationID 1 %27 = OpIMul %uint %uint_2 %26 - %28 = OpIAdd %uint %25 %27 + %28 = OpIAdd %uint %24 %27 %29 = OpCompositeExtract %uint %GlobalInvocationID 0 %30 = OpIAdd %uint %28 %29 OpStore %flatIndex %30 diff --git a/test/tint/bug/tint/914.wgsl.expected.dxc.hlsl b/test/tint/bug/tint/914.wgsl.expected.dxc.hlsl index a9ca8f839a..ee10d6f437 100644 --- a/test/tint/bug/tint/914.wgsl.expected.dxc.hlsl +++ b/test/tint/bug/tint/914.wgsl.expected.dxc.hlsl @@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) { float ACached = 0.0f; float BCached[4] = (float[4])0; { - [loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) { + [loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) { acc[index] = 0.0f; } } diff --git a/test/tint/bug/tint/914.wgsl.expected.fxc.hlsl b/test/tint/bug/tint/914.wgsl.expected.fxc.hlsl index a9ca8f839a..ee10d6f437 100644 --- a/test/tint/bug/tint/914.wgsl.expected.fxc.hlsl +++ b/test/tint/bug/tint/914.wgsl.expected.fxc.hlsl @@ -68,7 +68,7 @@ void main_inner(uint3 local_id, uint3 global_id, uint local_invocation_index) { float ACached = 0.0f; float BCached[4] = (float[4])0; { - [loop] for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) { + [loop] for(uint index = 0u; (index < 16u); index = (index + 1u)) { acc[index] = 0.0f; } } diff --git a/test/tint/bug/tint/914.wgsl.expected.glsl b/test/tint/bug/tint/914.wgsl.expected.glsl index a40e5ce01b..25b3ff30a2 100644 --- a/test/tint/bug/tint/914.wgsl.expected.glsl +++ b/test/tint/bug/tint/914.wgsl.expected.glsl @@ -77,7 +77,7 @@ void tint_symbol(uvec3 local_id, uvec3 global_id, uint local_invocation_index) { float ACached = 0.0f; float BCached[4] = float[4](0.0f, 0.0f, 0.0f, 0.0f); { - for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) { + for(uint index = 0u; (index < 16u); index = (index + 1u)) { acc[index] = 0.0f; } } diff --git a/test/tint/bug/tint/914.wgsl.expected.msl b/test/tint/bug/tint/914.wgsl.expected.msl index d70dc36902..714fdf9276 100644 --- a/test/tint/bug/tint/914.wgsl.expected.msl +++ b/test/tint/bug/tint/914.wgsl.expected.msl @@ -63,7 +63,7 @@ void tint_symbol_inner(uint3 local_id, uint3 global_id, uint local_invocation_in tint_array acc = {}; float ACached = 0.0f; tint_array BCached = {}; - for(uint index = 0u; (index < (4u * 4u)); index = (index + 1u)) { + for(uint index = 0u; (index < 16u); index = (index + 1u)) { acc[index] = 0.0f; } uint const ColPerThreadA = (64u / 16u); diff --git a/test/tint/bug/tint/914.wgsl.expected.spvasm b/test/tint/bug/tint/914.wgsl.expected.spvasm index 6506377514..86e951e87f 100644 --- a/test/tint/bug/tint/914.wgsl.expected.spvasm +++ b/test/tint/bug/tint/914.wgsl.expected.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.3 ; Generator: Google Tint Compiler; 0 -; Bound: 374 +; Bound: 373 ; Schema: 0 OpCapability Shader OpMemoryModel Logical GLSL450 @@ -127,7 +127,7 @@ %_arr_float_uint_4 = OpTypeArray %float %uint_4 %_ptr_Function__arr_float_uint_4 = OpTypePointer Function %_arr_float_uint_4 %152 = OpConstantNull %_arr_float_uint_4 - %367 = OpTypeFunction %void + %366 = OpTypeFunction %void %mm_readA = OpFunction %float None %24 %row = OpFunctionParameter %uint %col = OpFunctionParameter %uint @@ -287,334 +287,333 @@ OpBranch %157 %157 = OpLabel %159 = OpLoad %uint %index - %160 = OpIMul %uint %uint_4 %uint_4 - %161 = OpULessThan %bool %159 %160 - %158 = OpLogicalNot %bool %161 - OpSelectionMerge %162 None - OpBranchConditional %158 %163 %162 - %163 = OpLabel - OpBranch %155 + %160 = OpULessThan %bool %159 %uint_16 + %158 = OpLogicalNot %bool %160 + OpSelectionMerge %161 None + OpBranchConditional %158 %162 %161 %162 = OpLabel - %164 = OpLoad %uint %index - %165 = OpAccessChain %_ptr_Function_float %acc %164 - OpStore %165 %51 + OpBranch %155 + %161 = OpLabel + %163 = OpLoad %uint %index + %164 = OpAccessChain %_ptr_Function_float %acc %163 + OpStore %164 %51 OpBranch %156 %156 = OpLabel - %166 = OpLoad %uint %index - %167 = OpIAdd %uint %166 %uint_1 - OpStore %index %167 + %165 = OpLoad %uint %index + %166 = OpIAdd %uint %165 %uint_1 + OpStore %index %166 OpBranch %154 %155 = OpLabel - %168 = OpUDiv %uint %uint_64 %uint_16 - %169 = OpCompositeExtract %uint %local_id 0 - %170 = OpIMul %uint %169 %168 - %171 = OpUDiv %uint %uint_64 %uint_16 - %172 = OpCompositeExtract %uint %local_id 1 - %173 = OpIMul %uint %172 %171 + %167 = OpUDiv %uint %uint_64 %uint_16 + %168 = OpCompositeExtract %uint %local_id 0 + %169 = OpIMul %uint %168 %167 + %170 = OpUDiv %uint %uint_64 %uint_16 + %171 = OpCompositeExtract %uint %local_id 1 + %172 = OpIMul %uint %171 %170 OpStore %t %105 - OpBranch %175 - %175 = OpLabel - OpLoopMerge %176 %177 None - OpBranch %178 - %178 = OpLabel - %180 = OpLoad %uint %t - %181 = OpULessThan %bool %180 %141 - %179 = OpLogicalNot %bool %181 - OpSelectionMerge %182 None - OpBranchConditional %179 %183 %182 - %183 = OpLabel - OpBranch %176 - %182 = OpLabel - OpStore %innerRow %105 - OpBranch %185 - %185 = OpLabel - OpLoopMerge %186 %187 None - OpBranch %188 - %188 = OpLabel - %190 = OpLoad %uint %innerRow - %191 = OpULessThan %bool %190 %uint_4 - %189 = OpLogicalNot %bool %191 - OpSelectionMerge %192 None - OpBranchConditional %189 %193 %192 - %193 = OpLabel - OpBranch %186 - %192 = OpLabel - OpStore %innerCol %105 - OpBranch %195 - %195 = OpLabel - OpLoopMerge %196 %197 None - OpBranch %198 - %198 = OpLabel - %200 = OpLoad %uint %innerCol - %201 = OpULessThan %bool %200 %168 - %199 = OpLogicalNot %bool %201 - OpSelectionMerge %202 None - OpBranchConditional %199 %203 %202 - %203 = OpLabel - OpBranch %196 - %202 = OpLabel - %204 = OpLoad %uint %innerRow - %205 = OpIAdd %uint %130 %204 - %206 = OpLoad %uint %innerCol - %207 = OpIAdd %uint %170 %206 - %209 = OpLoad %uint %innerRow - %210 = OpIAdd %uint %134 %209 - %211 = OpLoad %uint %t - %212 = OpIMul %uint %211 %uint_64 - %213 = OpIAdd %uint %212 %207 - %208 = OpFunctionCall %float %mm_readA %210 %213 - %214 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %205 %207 - OpStore %214 %208 - OpBranch %197 - %197 = OpLabel - %215 = OpLoad %uint %innerCol - %216 = OpIAdd %uint %215 %uint_1 - OpStore %innerCol %216 - OpBranch %195 - %196 = OpLabel - OpBranch %187 - %187 = OpLabel - %217 = OpLoad %uint %innerRow - %218 = OpIAdd %uint %217 %uint_1 - OpStore %innerRow %218 - OpBranch %185 - %186 = OpLabel - OpStore %innerRow_0 %105 - OpBranch %220 - %220 = OpLabel - OpLoopMerge %221 %222 None - OpBranch %223 - %223 = OpLabel - %225 = OpLoad %uint %innerRow_0 - %226 = OpULessThan %bool %225 %171 - %224 = OpLogicalNot %bool %226 - OpSelectionMerge %227 None - OpBranchConditional %224 %228 %227 - %228 = OpLabel - OpBranch %221 - %227 = OpLabel - OpStore %innerCol_0 %105 - OpBranch %230 - %230 = OpLabel - OpLoopMerge %231 %232 None - OpBranch %233 - %233 = OpLabel - %235 = OpLoad %uint %innerCol_0 - %236 = OpULessThan %bool %235 %uint_4 - %234 = OpLogicalNot %bool %236 - OpSelectionMerge %237 None - OpBranchConditional %234 %238 %237 - %238 = OpLabel - OpBranch %231 - %237 = OpLabel - %239 = OpLoad %uint %innerRow_0 - %240 = OpIAdd %uint %173 %239 - %241 = OpLoad %uint %innerCol_0 - %242 = OpIAdd %uint %132 %241 - %244 = OpLoad %uint %t - %245 = OpIMul %uint %244 %uint_64 - %246 = OpIAdd %uint %245 %240 - %247 = OpLoad %uint %innerCol_0 - %248 = OpIAdd %uint %136 %247 - %243 = OpFunctionCall %float %mm_readB %246 %248 - %249 = OpLoad %uint %innerCol_0 - %250 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %249 %242 - OpStore %250 %243 - OpBranch %232 - %232 = OpLabel - %251 = OpLoad %uint %innerCol_0 - %252 = OpIAdd %uint %251 %uint_1 - OpStore %innerCol_0 %252 - OpBranch %230 - %231 = OpLabel - OpBranch %222 - %222 = OpLabel - %253 = OpLoad %uint %innerRow_0 - %254 = OpIAdd %uint %253 %uint_1 - OpStore %innerRow_0 %254 - OpBranch %220 - %221 = OpLabel - OpControlBarrier %uint_2 %uint_2 %uint_264 - OpStore %k %105 - OpBranch %257 - %257 = OpLabel - OpLoopMerge %258 %259 None - OpBranch %260 - %260 = OpLabel - %262 = OpLoad %uint %k - %263 = OpULessThan %bool %262 %uint_64 - %261 = OpLogicalNot %bool %263 - OpSelectionMerge %264 None - OpBranchConditional %261 %265 %264 - %265 = OpLabel - OpBranch %258 - %264 = OpLabel - OpStore %inner %105 - OpBranch %267 - %267 = OpLabel - OpLoopMerge %268 %269 None - OpBranch %270 - %270 = OpLabel - %272 = OpLoad %uint %inner - %273 = OpULessThan %bool %272 %uint_4 - %271 = OpLogicalNot %bool %273 - OpSelectionMerge %274 None - OpBranchConditional %271 %275 %274 - %275 = OpLabel - OpBranch %268 - %274 = OpLabel - %276 = OpLoad %uint %inner - %277 = OpAccessChain %_ptr_Function_float %BCached %276 - %278 = OpLoad %uint %k - %279 = OpLoad %uint %inner - %280 = OpIAdd %uint %132 %279 - %281 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %278 %280 - %282 = OpLoad %float %281 - OpStore %277 %282 - OpBranch %269 - %269 = OpLabel - %283 = OpLoad %uint %inner - %284 = OpIAdd %uint %283 %uint_1 - OpStore %inner %284 - OpBranch %267 - %268 = OpLabel - OpStore %innerRow_1 %105 - OpBranch %286 - %286 = OpLabel - OpLoopMerge %287 %288 None - OpBranch %289 - %289 = OpLabel - %291 = OpLoad %uint %innerRow_1 - %292 = OpULessThan %bool %291 %uint_4 - %290 = OpLogicalNot %bool %292 - OpSelectionMerge %293 None - OpBranchConditional %290 %294 %293 - %294 = OpLabel - OpBranch %287 - %293 = OpLabel - %295 = OpLoad %uint %innerRow_1 - %296 = OpIAdd %uint %130 %295 - %297 = OpLoad %uint %k - %298 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %296 %297 - %299 = OpLoad %float %298 - OpStore %ACached %299 - OpStore %innerCol_1 %105 - OpBranch %301 - %301 = OpLabel - OpLoopMerge %302 %303 None - OpBranch %304 - %304 = OpLabel - %306 = OpLoad %uint %innerCol_1 - %307 = OpULessThan %bool %306 %uint_4 - %305 = OpLogicalNot %bool %307 - OpSelectionMerge %308 None - OpBranchConditional %305 %309 %308 - %309 = OpLabel - OpBranch %302 - %308 = OpLabel - %310 = OpLoad %uint %innerRow_1 - %311 = OpIMul %uint %310 %uint_4 - %312 = OpLoad %uint %innerCol_1 - %313 = OpIAdd %uint %311 %312 - %314 = OpAccessChain %_ptr_Function_float %acc %313 - %315 = OpAccessChain %_ptr_Function_float %acc %313 - %316 = OpLoad %float %315 - %317 = OpLoad %float %ACached - %318 = OpLoad %uint %innerCol_1 - %319 = OpAccessChain %_ptr_Function_float %BCached %318 - %320 = OpLoad %float %319 - %321 = OpFMul %float %317 %320 - %322 = OpFAdd %float %316 %321 - OpStore %314 %322 - OpBranch %303 - %303 = OpLabel - %323 = OpLoad %uint %innerCol_1 - %324 = OpIAdd %uint %323 %uint_1 - OpStore %innerCol_1 %324 - OpBranch %301 - %302 = OpLabel - OpBranch %288 - %288 = OpLabel - %325 = OpLoad %uint %innerRow_1 - %326 = OpIAdd %uint %325 %uint_1 - OpStore %innerRow_1 %326 - OpBranch %286 - %287 = OpLabel - OpBranch %259 - %259 = OpLabel - %327 = OpLoad %uint %k - %328 = OpIAdd %uint %327 %uint_1 - OpStore %k %328 - OpBranch %257 - %258 = OpLabel - OpControlBarrier %uint_2 %uint_2 %uint_264 + OpBranch %174 + %174 = OpLabel + OpLoopMerge %175 %176 None OpBranch %177 %177 = OpLabel - %330 = OpLoad %uint %t - %331 = OpIAdd %uint %330 %uint_1 - OpStore %t %331 + %179 = OpLoad %uint %t + %180 = OpULessThan %bool %179 %141 + %178 = OpLogicalNot %bool %180 + OpSelectionMerge %181 None + OpBranchConditional %178 %182 %181 + %182 = OpLabel OpBranch %175 + %181 = OpLabel + OpStore %innerRow %105 + OpBranch %184 + %184 = OpLabel + OpLoopMerge %185 %186 None + OpBranch %187 + %187 = OpLabel + %189 = OpLoad %uint %innerRow + %190 = OpULessThan %bool %189 %uint_4 + %188 = OpLogicalNot %bool %190 + OpSelectionMerge %191 None + OpBranchConditional %188 %192 %191 + %192 = OpLabel + OpBranch %185 + %191 = OpLabel + OpStore %innerCol %105 + OpBranch %194 + %194 = OpLabel + OpLoopMerge %195 %196 None + OpBranch %197 + %197 = OpLabel + %199 = OpLoad %uint %innerCol + %200 = OpULessThan %bool %199 %167 + %198 = OpLogicalNot %bool %200 + OpSelectionMerge %201 None + OpBranchConditional %198 %202 %201 + %202 = OpLabel + OpBranch %195 + %201 = OpLabel + %203 = OpLoad %uint %innerRow + %204 = OpIAdd %uint %130 %203 + %205 = OpLoad %uint %innerCol + %206 = OpIAdd %uint %169 %205 + %208 = OpLoad %uint %innerRow + %209 = OpIAdd %uint %134 %208 + %210 = OpLoad %uint %t + %211 = OpIMul %uint %210 %uint_64 + %212 = OpIAdd %uint %211 %206 + %207 = OpFunctionCall %float %mm_readA %209 %212 + %213 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %204 %206 + OpStore %213 %207 + OpBranch %196 + %196 = OpLabel + %214 = OpLoad %uint %innerCol + %215 = OpIAdd %uint %214 %uint_1 + OpStore %innerCol %215 + OpBranch %194 + %195 = OpLabel + OpBranch %186 + %186 = OpLabel + %216 = OpLoad %uint %innerRow + %217 = OpIAdd %uint %216 %uint_1 + OpStore %innerRow %217 + OpBranch %184 + %185 = OpLabel + OpStore %innerRow_0 %105 + OpBranch %219 + %219 = OpLabel + OpLoopMerge %220 %221 None + OpBranch %222 + %222 = OpLabel + %224 = OpLoad %uint %innerRow_0 + %225 = OpULessThan %bool %224 %170 + %223 = OpLogicalNot %bool %225 + OpSelectionMerge %226 None + OpBranchConditional %223 %227 %226 + %227 = OpLabel + OpBranch %220 + %226 = OpLabel + OpStore %innerCol_0 %105 + OpBranch %229 + %229 = OpLabel + OpLoopMerge %230 %231 None + OpBranch %232 + %232 = OpLabel + %234 = OpLoad %uint %innerCol_0 + %235 = OpULessThan %bool %234 %uint_4 + %233 = OpLogicalNot %bool %235 + OpSelectionMerge %236 None + OpBranchConditional %233 %237 %236 + %237 = OpLabel + OpBranch %230 + %236 = OpLabel + %238 = OpLoad %uint %innerRow_0 + %239 = OpIAdd %uint %172 %238 + %240 = OpLoad %uint %innerCol_0 + %241 = OpIAdd %uint %132 %240 + %243 = OpLoad %uint %t + %244 = OpIMul %uint %243 %uint_64 + %245 = OpIAdd %uint %244 %239 + %246 = OpLoad %uint %innerCol_0 + %247 = OpIAdd %uint %136 %246 + %242 = OpFunctionCall %float %mm_readB %245 %247 + %248 = OpLoad %uint %innerCol_0 + %249 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %248 %241 + OpStore %249 %242 + OpBranch %231 + %231 = OpLabel + %250 = OpLoad %uint %innerCol_0 + %251 = OpIAdd %uint %250 %uint_1 + OpStore %innerCol_0 %251 + OpBranch %229 + %230 = OpLabel + OpBranch %221 + %221 = OpLabel + %252 = OpLoad %uint %innerRow_0 + %253 = OpIAdd %uint %252 %uint_1 + OpStore %innerRow_0 %253 + OpBranch %219 + %220 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + OpStore %k %105 + OpBranch %256 + %256 = OpLabel + OpLoopMerge %257 %258 None + OpBranch %259 + %259 = OpLabel + %261 = OpLoad %uint %k + %262 = OpULessThan %bool %261 %uint_64 + %260 = OpLogicalNot %bool %262 + OpSelectionMerge %263 None + OpBranchConditional %260 %264 %263 + %264 = OpLabel + OpBranch %257 + %263 = OpLabel + OpStore %inner %105 + OpBranch %266 + %266 = OpLabel + OpLoopMerge %267 %268 None + OpBranch %269 + %269 = OpLabel + %271 = OpLoad %uint %inner + %272 = OpULessThan %bool %271 %uint_4 + %270 = OpLogicalNot %bool %272 + OpSelectionMerge %273 None + OpBranchConditional %270 %274 %273 + %274 = OpLabel + OpBranch %267 + %273 = OpLabel + %275 = OpLoad %uint %inner + %276 = OpAccessChain %_ptr_Function_float %BCached %275 + %277 = OpLoad %uint %k + %278 = OpLoad %uint %inner + %279 = OpIAdd %uint %132 %278 + %280 = OpAccessChain %_ptr_Workgroup_float %mm_Bsub %277 %279 + %281 = OpLoad %float %280 + OpStore %276 %281 + OpBranch %268 + %268 = OpLabel + %282 = OpLoad %uint %inner + %283 = OpIAdd %uint %282 %uint_1 + OpStore %inner %283 + OpBranch %266 + %267 = OpLabel + OpStore %innerRow_1 %105 + OpBranch %285 + %285 = OpLabel + OpLoopMerge %286 %287 None + OpBranch %288 + %288 = OpLabel + %290 = OpLoad %uint %innerRow_1 + %291 = OpULessThan %bool %290 %uint_4 + %289 = OpLogicalNot %bool %291 + OpSelectionMerge %292 None + OpBranchConditional %289 %293 %292 + %293 = OpLabel + OpBranch %286 + %292 = OpLabel + %294 = OpLoad %uint %innerRow_1 + %295 = OpIAdd %uint %130 %294 + %296 = OpLoad %uint %k + %297 = OpAccessChain %_ptr_Workgroup_float %mm_Asub %295 %296 + %298 = OpLoad %float %297 + OpStore %ACached %298 + OpStore %innerCol_1 %105 + OpBranch %300 + %300 = OpLabel + OpLoopMerge %301 %302 None + OpBranch %303 + %303 = OpLabel + %305 = OpLoad %uint %innerCol_1 + %306 = OpULessThan %bool %305 %uint_4 + %304 = OpLogicalNot %bool %306 + OpSelectionMerge %307 None + OpBranchConditional %304 %308 %307 + %308 = OpLabel + OpBranch %301 + %307 = OpLabel + %309 = OpLoad %uint %innerRow_1 + %310 = OpIMul %uint %309 %uint_4 + %311 = OpLoad %uint %innerCol_1 + %312 = OpIAdd %uint %310 %311 + %313 = OpAccessChain %_ptr_Function_float %acc %312 + %314 = OpAccessChain %_ptr_Function_float %acc %312 + %315 = OpLoad %float %314 + %316 = OpLoad %float %ACached + %317 = OpLoad %uint %innerCol_1 + %318 = OpAccessChain %_ptr_Function_float %BCached %317 + %319 = OpLoad %float %318 + %320 = OpFMul %float %316 %319 + %321 = OpFAdd %float %315 %320 + OpStore %313 %321 + OpBranch %302 + %302 = OpLabel + %322 = OpLoad %uint %innerCol_1 + %323 = OpIAdd %uint %322 %uint_1 + OpStore %innerCol_1 %323 + OpBranch %300 + %301 = OpLabel + OpBranch %287 + %287 = OpLabel + %324 = OpLoad %uint %innerRow_1 + %325 = OpIAdd %uint %324 %uint_1 + OpStore %innerRow_1 %325 + OpBranch %285 + %286 = OpLabel + OpBranch %258 + %258 = OpLabel + %326 = OpLoad %uint %k + %327 = OpIAdd %uint %326 %uint_1 + OpStore %k %327 + OpBranch %256 + %257 = OpLabel + OpControlBarrier %uint_2 %uint_2 %uint_264 + OpBranch %176 %176 = OpLabel + %329 = OpLoad %uint %t + %330 = OpIAdd %uint %329 %uint_1 + OpStore %t %330 + OpBranch %174 + %175 = OpLabel OpStore %innerRow_2 %105 - OpBranch %333 - %333 = OpLabel - OpLoopMerge %334 %335 None - OpBranch %336 - %336 = OpLabel - %338 = OpLoad %uint %innerRow_2 - %339 = OpULessThan %bool %338 %uint_4 - %337 = OpLogicalNot %bool %339 - OpSelectionMerge %340 None - OpBranchConditional %337 %341 %340 - %341 = OpLabel - OpBranch %334 - %340 = OpLabel - OpStore %innerCol_2 %105 - OpBranch %343 - %343 = OpLabel - OpLoopMerge %344 %345 None - OpBranch %346 - %346 = OpLabel - %348 = OpLoad %uint %innerCol_2 - %349 = OpULessThan %bool %348 %uint_4 - %347 = OpLogicalNot %bool %349 - OpSelectionMerge %350 None - OpBranchConditional %347 %351 %350 - %351 = OpLabel - OpBranch %344 - %350 = OpLabel - %352 = OpLoad %uint %innerRow_2 - %353 = OpIMul %uint %352 %uint_4 - %354 = OpLoad %uint %innerCol_2 - %355 = OpIAdd %uint %353 %354 - %357 = OpLoad %uint %innerRow_2 - %358 = OpIAdd %uint %134 %357 - %359 = OpLoad %uint %innerCol_2 - %360 = OpIAdd %uint %136 %359 - %361 = OpAccessChain %_ptr_Function_float %acc %355 - %362 = OpLoad %float %361 - %356 = OpFunctionCall %void %mm_write %358 %360 %362 - OpBranch %345 - %345 = OpLabel - %363 = OpLoad %uint %innerCol_2 - %364 = OpIAdd %uint %363 %uint_1 - OpStore %innerCol_2 %364 - OpBranch %343 - %344 = OpLabel + OpBranch %332 + %332 = OpLabel + OpLoopMerge %333 %334 None OpBranch %335 %335 = OpLabel - %365 = OpLoad %uint %innerRow_2 - %366 = OpIAdd %uint %365 %uint_1 - OpStore %innerRow_2 %366 + %337 = OpLoad %uint %innerRow_2 + %338 = OpULessThan %bool %337 %uint_4 + %336 = OpLogicalNot %bool %338 + OpSelectionMerge %339 None + OpBranchConditional %336 %340 %339 + %340 = OpLabel OpBranch %333 + %339 = OpLabel + OpStore %innerCol_2 %105 + OpBranch %342 + %342 = OpLabel + OpLoopMerge %343 %344 None + OpBranch %345 + %345 = OpLabel + %347 = OpLoad %uint %innerCol_2 + %348 = OpULessThan %bool %347 %uint_4 + %346 = OpLogicalNot %bool %348 + OpSelectionMerge %349 None + OpBranchConditional %346 %350 %349 + %350 = OpLabel + OpBranch %343 + %349 = OpLabel + %351 = OpLoad %uint %innerRow_2 + %352 = OpIMul %uint %351 %uint_4 + %353 = OpLoad %uint %innerCol_2 + %354 = OpIAdd %uint %352 %353 + %356 = OpLoad %uint %innerRow_2 + %357 = OpIAdd %uint %134 %356 + %358 = OpLoad %uint %innerCol_2 + %359 = OpIAdd %uint %136 %358 + %360 = OpAccessChain %_ptr_Function_float %acc %354 + %361 = OpLoad %float %360 + %355 = OpFunctionCall %void %mm_write %357 %359 %361 + OpBranch %344 + %344 = OpLabel + %362 = OpLoad %uint %innerCol_2 + %363 = OpIAdd %uint %362 %uint_1 + OpStore %innerCol_2 %363 + OpBranch %342 + %343 = OpLabel + OpBranch %334 %334 = OpLabel + %364 = OpLoad %uint %innerRow_2 + %365 = OpIAdd %uint %364 %uint_1 + OpStore %innerRow_2 %365 + OpBranch %332 + %333 = OpLabel OpReturn OpFunctionEnd - %main = OpFunction %void None %367 - %369 = OpLabel - %371 = OpLoad %v3uint %local_id_1 - %372 = OpLoad %v3uint %global_id_1 - %373 = OpLoad %uint %local_invocation_index_1 - %370 = OpFunctionCall %void %main_inner %371 %372 %373 + %main = OpFunction %void None %366 + %368 = OpLabel + %370 = OpLoad %v3uint %local_id_1 + %371 = OpLoad %v3uint %global_id_1 + %372 = OpLoad %uint %local_invocation_index_1 + %369 = OpFunctionCall %void %main_inner %370 %371 %372 OpReturn OpFunctionEnd diff --git a/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.dxc.hlsl index c3dc99b5f4..f832a2f7ec 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const float16_t r = (float16_t(1.0h) * float16_t(2.0h)); + const float16_t r = float16_t(2.0h); return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.glsl b/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.glsl index ad5aa1a24b..5d151fae32 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.glsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/f16.wgsl.expected.glsl @@ -2,7 +2,7 @@ #extension GL_AMD_gpu_shader_half_float : require void f() { - float16_t r = (1.0hf * 2.0hf); + float16_t r = 2.0hf; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.dxc.hlsl index a2d27f9b82..0f4b84fd27 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const float r = (1.0f * 2.0f); + const float r = 2.0f; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.fxc.hlsl index a2d27f9b82..0f4b84fd27 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.fxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const float r = (1.0f * 2.0f); + const float r = 2.0f; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.glsl b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.glsl index c0865da86b..d9a5099b0e 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.glsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/f32.wgsl.expected.glsl @@ -1,7 +1,7 @@ #version 310 es void f() { - float r = (1.0f * 2.0f); + float r = 2.0f; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.dxc.hlsl index 4b93ecb2bf..1207f53182 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const int r = (1 * 2); + const int r = 2; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.fxc.hlsl index 4b93ecb2bf..1207f53182 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.fxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const int r = (1 * 2); + const int r = 2; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.glsl b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.glsl index deb72c9d56..bde01222d3 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.glsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/i32.wgsl.expected.glsl @@ -1,7 +1,7 @@ #version 310 es void f() { - int r = (1 * 2); + int r = 2; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.dxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.dxc.hlsl index 4623783d4f..7816c13d18 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.dxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const uint r = (1u * 2u); + const uint r = 2u; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.fxc.hlsl b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.fxc.hlsl index 4623783d4f..7816c13d18 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.fxc.hlsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ [numthreads(1, 1, 1)] void f() { - const uint r = (1u * 2u); + const uint r = 2u; return; } diff --git a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.glsl b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.glsl index 077b8603b0..57feb4db57 100644 --- a/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.glsl +++ b/test/tint/expressions/binary/mul/scalar-scalar/u32.wgsl.expected.glsl @@ -1,7 +1,7 @@ #version 310 es void f() { - uint r = (1u * 2u); + uint r = 2u; } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/test/tint/samples/function.wgsl.expected.dxc.hlsl b/test/tint/samples/function.wgsl.expected.dxc.hlsl index 9de43a2241..7245500ee9 100644 --- a/test/tint/samples/function.wgsl.expected.dxc.hlsl +++ b/test/tint/samples/function.wgsl.expected.dxc.hlsl @@ -1,5 +1,5 @@ float main() { - return (((2.0f * 3.0f) - 4.0f) / 5.0f); + return (2.0f / 5.0f); } [numthreads(2, 1, 1)] diff --git a/test/tint/samples/function.wgsl.expected.fxc.hlsl b/test/tint/samples/function.wgsl.expected.fxc.hlsl index 9de43a2241..7245500ee9 100644 --- a/test/tint/samples/function.wgsl.expected.fxc.hlsl +++ b/test/tint/samples/function.wgsl.expected.fxc.hlsl @@ -1,5 +1,5 @@ float main() { - return (((2.0f * 3.0f) - 4.0f) / 5.0f); + return (2.0f / 5.0f); } [numthreads(2, 1, 1)] diff --git a/test/tint/samples/function.wgsl.expected.msl b/test/tint/samples/function.wgsl.expected.msl index 4ab43dd4cd..b0f4f9bb25 100644 --- a/test/tint/samples/function.wgsl.expected.msl +++ b/test/tint/samples/function.wgsl.expected.msl @@ -2,7 +2,7 @@ using namespace metal; float tint_symbol() { - return (((2.0f * 3.0f) - 4.0f) / 5.0f); + return (2.0f / 5.0f); } kernel void ep() { diff --git a/test/tint/samples/function.wgsl.expected.spvasm b/test/tint/samples/function.wgsl.expected.spvasm index affe05169a..d6ef6f791c 100644 --- a/test/tint/samples/function.wgsl.expected.spvasm +++ b/test/tint/samples/function.wgsl.expected.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.3 ; Generator: Google Tint Compiler; 0 -; Bound: 16 +; Bound: 12 ; Schema: 0 OpCapability Shader OpMemoryModel Logical GLSL450 @@ -12,19 +12,15 @@ %float = OpTypeFloat 32 %1 = OpTypeFunction %float %float_2 = OpConstant %float 2 - %float_3 = OpConstant %float 3 - %float_4 = OpConstant %float 4 %float_5 = OpConstant %float 5 %void = OpTypeVoid - %12 = OpTypeFunction %void + %8 = OpTypeFunction %void %main = OpFunction %float None %1 %4 = OpLabel - %7 = OpFMul %float %float_2 %float_3 - %9 = OpFSub %float %7 %float_4 - %11 = OpFDiv %float %9 %float_5 - OpReturnValue %11 + %7 = OpFDiv %float %float_2 %float_5 + OpReturnValue %7 OpFunctionEnd - %ep = OpFunction %void None %12 - %15 = OpLabel + %ep = OpFunction %void None %8 + %11 = OpLabel OpReturn OpFunctionEnd