From 9707e6bb38c249e49e372b46bd8492d32e20315d Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 25 May 2022 19:28:55 +0000 Subject: [PATCH] tint: Rework sem::Constant to be variant-of-vector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of vector-of-variant. This: • Makes it impossible to produce a mix of scalar variant types, which would make no sense. • Reduces the size of a Constant, by removing the union-tag from each element. Also clean up terminology. Rename 'Constant::Scalar' to 'Constant::Element'. Scalars are well-defined in WGSL, and with the introduction of abstract-numerics, this no longer makes sense. Bug: tint:1504 Change-Id: I599aa97ad1ea798b7db8e512a5990ba75827faad Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91304 Reviewed-by: Antonio Maiorano Kokoro: Kokoro Commit-Queue: Ben Clayton Commit-Queue: Ben Clayton --- src/tint/BUILD.gn | 2 + src/tint/CMakeLists.txt | 3 + src/tint/resolver/materialize_test.cc | 12 +- src/tint/resolver/resolver.cc | 6 +- src/tint/resolver/resolver.h | 9 +- src/tint/resolver/resolver_constants.cc | 156 +++++++++------ src/tint/resolver/resolver_constants_test.cc | 48 +++-- src/tint/resolver/validator.cc | 2 +- src/tint/sem/constant.cc | 57 +++--- src/tint/sem/constant.h | 132 +++++++++--- src/tint/sem/constant_test.cc | 199 +++++++++++++++++++ src/tint/transform/fold_constants.cc | 45 +++-- src/tint/utils/compiler_macros.h | 39 ++++ src/tint/writer/hlsl/generator_impl.cc | 3 +- src/tint/writer/spirv/builder.cc | 2 +- test/tint/BUILD.gn | 1 + 16 files changed, 545 insertions(+), 171 deletions(-) create mode 100644 src/tint/sem/constant_test.cc create mode 100644 src/tint/utils/compiler_macros.h diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 59ed51c749..0805d5d48e 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -529,6 +529,8 @@ libtint_source_set("libtint_core_all_src") { "transform/zero_init_workgroup_memory.h", "utils/bitcast.h", "utils/block_allocator.h", + "utils/compiler_macros.h", + "utils/concat.h", "utils/crc32.h", "utils/debugger.cc", "utils/debugger.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 64b25a6177..2fd7d0d108 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -456,6 +456,8 @@ set(TINT_LIB_SRCS transform/zero_init_workgroup_memory.h utils/bitcast.h utils/block_allocator.h + utils/compiler_macros.h + utils/concat.h utils/crc32.h utils/enum_set.h utils/hash.h @@ -801,6 +803,7 @@ if(TINT_BUILD_TESTS) sem/atomic.cc sem/bool_test.cc sem/builtin_test.cc + sem/constant_test.cc sem/depth_multisampled_texture_test.cc sem/depth_texture_test.cc sem/expression_test.cc diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index e2e1a02ff1..d2cb47205b 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -252,11 +252,13 @@ TEST_P(MaterializeAbstractNumeric, Test) { uint32_t num_elems = 0; const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems); EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty); - std::visit( - [&](auto&& v) { - EXPECT_EQ(expr->ConstantValue().Elements(), sem::Constant::Scalars(num_elems, {v})); - }, - data.materialized_value); + expr->ConstantValue().WithElements([&](auto&& vec) { + using VEC_TY = std::decay_t; + using EL_TY = typename VEC_TY::value_type; + ASSERT_TRUE(std::holds_alternative(data.materialized_value)); + VEC_TY expected(num_elems, std::get(data.materialized_value)); + EXPECT_EQ(vec, expected); + }); }; switch (expectation) { diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 9d3d7a4360..afa41c4f68 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -806,7 +806,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { return false; } - ws[i].value = static_cast(value.Element(0).value); + ws[i].value = value.Element(0); } current_function_->SetWorkgroupSize(std::move(ws)); @@ -1119,7 +1119,7 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr, << (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : ""); return nullptr; } - auto materialized_val = ConstantCast(expr_val, target_ty); + auto materialized_val = ConvertValue(expr_val, target_ty); auto* m = builder_->create(expr, current_statement_, materialized_val); m->Behaviors() = expr->Behaviors(); builder_->Sem().Replace(expr->Declaration(), m); @@ -2022,7 +2022,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } - count = static_cast(count_val.Element(0).value); + count = count_val.Element(0); } auto size = std::max(count, 1) * stride; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index b03bb32c52..07e963640f 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -354,11 +354,10 @@ class Resolver { ////////////////////////////////////////////////////////////////////////////// /// Constant value evaluation methods ////////////////////////////////////////////////////////////////////////////// - /// Cast `Value` to `target_type` - /// @return the casted value - sem::Constant ConstantCast(const sem::Constant& value, - const sem::Type* target_type, - const sem::Type* target_element_type = nullptr); + + /// Convert the `value` to `target_type` + /// @return the converted value + sem::Constant ConvertValue(const sem::Constant& value, const sem::Type* target_type); sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type); sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc index 265b06d68f..5d1f389d35 100644 --- a/src/tint/resolver/resolver_constants.cc +++ b/src/tint/resolver/resolver_constants.cc @@ -14,39 +14,70 @@ #include "src/tint/resolver/resolver.h" +#include + #include "src/tint/sem/abstract_float.h" #include "src/tint/sem/abstract_int.h" #include "src/tint/sem/constant.h" #include "src/tint/sem/type_constructor.h" +#include "src/tint/utils/compiler_macros.h" #include "src/tint/utils/map.h" +#include "src/tint/utils/transform.h" using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { + namespace { -sem::Constant::Scalars CastScalars(sem::Constant::Scalars in, const sem::Type* target_type) { - sem::Constant::Scalars out; - out.reserve(in.size()); - for (auto v : in) { - // TODO(crbug.com/tint/1504): Check that value fits in new type - out.emplace_back(Switch( - target_type, // - [&](const sem::AbstractInt*) { return sem::Constant::Cast(v); }, - [&](const sem::AbstractFloat*) { return sem::Constant::Cast(v); }, - [&](const sem::I32*) { return sem::Constant::Cast(v); }, - [&](const sem::U32*) { return sem::Constant::Cast(v); }, - [&](const sem::F32*) { return sem::Constant::Cast(v); }, - [&](const sem::F16*) { return sem::Constant::Cast(v); }, - [&](const sem::Bool*) { return sem::Constant::Cast(v); }, - [&](Default) { - diag::List diags; - TINT_UNREACHABLE(Semantic, diags) - << "invalid element type " << target_type->TypeInfo().name; - return sem::Constant::Scalar(false); - })); - } - return out; +/// Converts all the element values of `in` to the type `T`. +/// @param elements_in the vector of elements to be converted +/// @returns the elements converted to type T. +template +sem::Constant::Elements Convert(const ELEMENTS_IN& elements_in) { + TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE(); + + using E = UnwrapNumber; + return utils::Transform(elements_in, [&](auto value_in) { + if constexpr (std::is_same_v) { + return AInt(value_in != 0); + } + + E converted = static_cast(value_in); + if constexpr (IsFloatingPoint) { + return AFloat(converted); + } else { + return AInt(converted); + } + }); + + TINT_END_DISABLE_WARNING_UNREACHABLE_CODE(); +} + +/// Converts and returns all the element values of `in` to the semantic type `el_ty`. +/// @param in the constant to convert +/// @param el_ty the target element type +/// @returns the elements converted to `type` +sem::Constant::Elements Convert(const sem::Constant::Elements& in, const sem::Type* el_ty) { + return std::visit( + [&](auto&& v) { + return Switch( + el_ty, // + [&](const sem::AbstractInt*) { return Convert(v); }, + [&](const sem::AbstractFloat*) { return Convert(v); }, + [&](const sem::I32*) { return Convert(v); }, + [&](const sem::U32*) { return Convert(v); }, + [&](const sem::F32*) { return Convert(v); }, + [&](const sem::F16*) { return Convert(v); }, + [&](const sem::Bool*) { return Convert(v); }, + [&](Default) -> sem::Constant::Elements { + diag::List diags; + TINT_UNREACHABLE(Semantic, diags) + << "invalid element type " << el_ty->TypeInfo().name; + return {}; + }); + }, + in); } } // namespace @@ -72,43 +103,42 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite return sem::Constant{type, {AFloat(lit->value)}}; }, [&](const ast::BoolLiteralExpression* lit) { - return sem::Constant{type, {lit->value}}; + return sem::Constant{type, {AInt(lit->value ? 1 : 0)}}; }); } sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, - const sem::Type* type) { + const sem::Type* ty) { uint32_t result_size = 0; - auto* el_ty = sem::Type::ElementOf(type, &result_size); + auto* el_ty = sem::Type::ElementOf(ty, &result_size); if (!el_ty) { return {}; } // ElementOf() will also return the element type of array, which we do not support. - if (type->Is()) { + if (ty->Is()) { return {}; } // For zero value init, return 0s if (call->args.empty()) { - using Scalars = sem::Constant::Scalars; return Switch( el_ty, [&](const sem::AbstractInt*) { - return sem::Constant(type, Scalars(result_size, AInt(0))); + return sem::Constant(ty, std::vector(result_size, AInt(0))); }, [&](const sem::AbstractFloat*) { - return sem::Constant(type, Scalars(result_size, AFloat(0))); + return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, - [&](const sem::I32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); }, - [&](const sem::U32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); }, - [&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, - [&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, - [&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); }); + [&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, + [&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, + [&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, + [&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, + [&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }); } - // Build value for type_ctor from each child value by casting to type_ctor's type. - sem::Constant::Scalars elems; + // Build value for type_ctor from each child value by converting to type_ctor's type. + std::optional elements; for (auto* expr : call->args) { auto* arg = builder_->Sem().Get(expr); if (!arg) { @@ -118,42 +148,52 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, if (!value) { return {}; } - elems.insert(elems.end(), value.Elements().begin(), value.Elements().end()); - } - // Splat single-value initializers - if (elems.size() == 1) { - for (uint32_t i = 0; i < result_size - 1; ++i) { - elems.emplace_back(elems[0]); + // Convert the elements to the desired type. + auto converted = Convert(value.GetElements(), el_ty); + + if (elements.has_value()) { + // Append the converted vector to elements + std::visit( + [&](auto&& dst) { + using VEC_TY = std::decay_t; + const auto& src = std::get(converted); + dst.insert(dst.end(), src.begin(), src.end()); + }, + elements.value()); + } else { + elements = std::move(converted); } } - // Finally cast the elements to the desired type. - auto cast = CastScalars(elems, el_ty); + // Splat single-value initializers + std::visit( + [&](auto&& v) { + if (v.size() == 1) { + for (uint32_t i = 0; i < result_size - 1; ++i) { + v.emplace_back(v[0]); + } + } + }, + elements.value()); - return sem::Constant(type, std::move(cast)); + return sem::Constant(ty, std::move(elements.value())); } -sem::Constant Resolver::ConstantCast(const sem::Constant& value, - const sem::Type* target_type, - const sem::Type* target_element_type /* = nullptr */) { - if (value.Type() == target_type) { +sem::Constant Resolver::ConvertValue(const sem::Constant& value, const sem::Type* ty) { + if (value.Type() == ty) { return value; } - if (target_element_type == nullptr) { - target_element_type = sem::Type::ElementOf(target_type); - } - if (target_element_type == nullptr) { + auto* el_ty = sem::Type::ElementOf(ty); + if (el_ty == nullptr) { return {}; } - if (value.ElementType() == target_element_type) { - return sem::Constant(target_type, value.Elements()); + if (value.ElementType() == el_ty) { + return sem::Constant(ty, value.GetElements()); } - auto elems = CastScalars(value.Elements(), target_element_type); - - return sem::Constant(target_type, elems); + return sem::Constant(ty, Convert(value.GetElements(), el_ty)); } } // namespace tint::resolver diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc index 05e6e4c4dc..7be067a918 100644 --- a/src/tint/resolver/resolver_constants_test.cc +++ b/src/tint/resolver/resolver_constants_test.cc @@ -23,8 +23,6 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { namespace { -using Scalar = sem::Constant::Scalar; - using ResolverConstantsTest = ResolverTest; TEST_F(ResolverConstantsTest, Scalar_i32) { @@ -38,7 +36,7 @@ TEST_F(ResolverConstantsTest, Scalar_i32) { EXPECT_TRUE(sem->Type()->Is()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 99); } @@ -53,7 +51,7 @@ TEST_F(ResolverConstantsTest, Scalar_u32) { EXPECT_TRUE(sem->Type()->Is()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 99u); } @@ -68,7 +66,7 @@ TEST_F(ResolverConstantsTest, Scalar_f32) { EXPECT_TRUE(sem->Type()->Is()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 9.9f); } @@ -83,7 +81,7 @@ TEST_F(ResolverConstantsTest, Scalar_bool) { EXPECT_TRUE(sem->Type()->Is()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 1u); EXPECT_EQ(sem->ConstantValue().Element(0), true); } @@ -100,7 +98,7 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 0); EXPECT_EQ(sem->ConstantValue().Element(1).value, 0); EXPECT_EQ(sem->ConstantValue().Element(2).value, 0); @@ -119,7 +117,7 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 0u); EXPECT_EQ(sem->ConstantValue().Element(1).value, 0u); EXPECT_EQ(sem->ConstantValue().Element(2).value, 0u); @@ -138,7 +136,7 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 0.0); EXPECT_EQ(sem->ConstantValue().Element(1).value, 0.0); EXPECT_EQ(sem->ConstantValue().Element(2).value, 0.0); @@ -157,7 +155,7 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0), false); EXPECT_EQ(sem->ConstantValue().Element(1), false); EXPECT_EQ(sem->ConstantValue().Element(2), false); @@ -176,7 +174,7 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 99); EXPECT_EQ(sem->ConstantValue().Element(1).value, 99); EXPECT_EQ(sem->ConstantValue().Element(2).value, 99); @@ -195,7 +193,7 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 99u); EXPECT_EQ(sem->ConstantValue().Element(1).value, 99u); EXPECT_EQ(sem->ConstantValue().Element(2).value, 99u); @@ -214,7 +212,7 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 9.9f); EXPECT_EQ(sem->ConstantValue().Element(1).value, 9.9f); EXPECT_EQ(sem->ConstantValue().Element(2).value, 9.9f); @@ -233,7 +231,7 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0), true); EXPECT_EQ(sem->ConstantValue().Element(1), true); EXPECT_EQ(sem->ConstantValue().Element(2), true); @@ -252,7 +250,7 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); @@ -271,7 +269,7 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); @@ -290,7 +288,7 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1.f); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2.f); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3.f); @@ -309,7 +307,7 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0), true); EXPECT_EQ(sem->ConstantValue().Element(1), false); EXPECT_EQ(sem->ConstantValue().Element(2), true); @@ -328,7 +326,7 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); @@ -347,7 +345,7 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); @@ -366,7 +364,7 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1.f); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2.f); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3.f); @@ -385,13 +383,13 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0), true); EXPECT_EQ(sem->ConstantValue().Element(1), false); EXPECT_EQ(sem->ConstantValue().Element(2), true); } -TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) { +TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_i32) { auto* expr = vec3(vec3(1.1_f, 2.2_f, 3.3_f)); WrapInFunction(expr); @@ -404,7 +402,7 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); @@ -423,7 +421,7 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) { EXPECT_EQ(sem->Type()->As()->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); - ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); EXPECT_EQ(sem->ConstantValue().Element(0).value, 10.f); EXPECT_EQ(sem->ConstantValue().Element(1).value, 20.f); EXPECT_EQ(sem->ConstantValue().Element(2).value, 30.f); diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index fdd74e9576..fd599887e7 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1535,7 +1535,7 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const { }); if (is_const_expr) { auto vector = builtin->Parameters()[index]->Type()->Is(); - for (size_t i = 0; i < values.Elements().size(); i++) { + for (size_t i = 0, n = values.ElementCount(); i < n; i++) { auto value = values.Element(i).value; if (value < min || value > max) { if (vector) { diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index 98c724cf42..1fa83f5046 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -23,29 +23,19 @@ namespace tint::sem { namespace { - -const Type* CheckElemType(const Type* ty, size_t num_scalars) { - diag::List diag; - if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { - uint32_t count = 0; - auto* el_ty = Type::ElementOf(ty, &count); - if (num_scalars != count) { - TINT_ICE(Semantic, diag) << "sem::Constant() type <-> scalar mismatch. type: '" - << ty->TypeInfo().name << "' scalar: " << num_scalars; - } - TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar()); - return el_ty; - } - TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name; - return nullptr; +size_t CountElements(const Constant::Elements& elements) { + return std::visit([](auto&& vec) { return vec.size(); }, elements); } - } // namespace Constant::Constant() {} -Constant::Constant(const sem::Type* ty, Scalars els) - : type_(ty), elem_type_(CheckElemType(ty, els.size())), elems_(std::move(els)) {} +Constant::Constant(const sem::Type* ty, Elements els) + : type_(ty), elem_type_(CheckElemType(ty, CountElements(els))), elems_(std::move(els)) {} + +Constant::Constant(const sem::Type* ty, AInts vec) : Constant(ty, Elements{std::move(vec)}) {} + +Constant::Constant(const sem::Type* ty, AFloats vec) : Constant(ty, Elements{std::move(vec)}) {} Constant::Constant(const Constant&) = default; @@ -54,16 +44,31 @@ Constant::~Constant() = default; Constant& Constant::operator=(const Constant& rhs) = default; bool Constant::AnyZero() const { - for (auto scalar : elems_) { - auto is_zero = [&](auto&& s) { - using T = std::remove_reference_t; - return s == T(0); - }; - if (std::visit(is_zero, scalar)) { - return true; + return WithElements([&](auto&& vec) { + for (auto scalar : vec) { + using T = std::remove_reference_t; + if (scalar == T(0)) { + return true; + } } + return false; + }); +} + +const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) { + diag::List diag; + if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { + uint32_t count = 0; + auto* el_ty = Type::ElementOf(ty, &count); + if (num_elements != count) { + TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '" + << ty->TypeInfo().name << "' element: " << num_elements; + } + TINT_ASSERT(Semantic, el_ty->is_abstract_or_scalar()); + return el_ty; } - return false; + TINT_UNREACHABLE(Semantic, diag) << "Unsupported sem::Constant type: " << ty->TypeInfo().name; + return nullptr; } } // namespace tint::sem diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index ea143b618b..43109f7ccd 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -15,7 +15,10 @@ #ifndef SRC_TINT_SEM_CONSTANT_H_ #define SRC_TINT_SEM_CONSTANT_H_ -#include +#include +// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT +#include +#include // NOLINT(build/include_order) #include #include "src/tint/program_builder.h" @@ -23,15 +26,31 @@ namespace tint::sem { -/// A Constant is compile-time known expression value, expressed as a flattened -/// list of scalar values. Value may be of a scalar or vector type. +/// A Constant holds a compile-time evaluated expression value, expressed as a flattened list of +/// element values. The expression type may be of an abstract-numeric, scalar, vector or matrix +/// type. Constant holds the element values in either a vector of abstract-integer (AInt) or +/// abstract-float (AFloat), depending on the element type. class Constant { public: - /// Scalar holds a single constant scalar value - one of: AInt, AFloat or bool. - using Scalar = std::variant; + /// AInts is a vector of AInt, used to hold elements of the WGSL types: + /// * abstract-integer + /// * i32 + /// * u32 + /// * bool (0 or 1) + using AInts = std::vector; - /// Scalars is a list of scalar values - using Scalars = std::vector; + /// AFloats is a vector of AFloat, used to hold elements of the WGSL types: + /// * abstract-float + /// * f32 + /// * f16 + using AFloats = std::vector; + + /// Elements is either a vector of AInts or AFloats + using Elements = std::variant; + + /// Helper that resolves to either AInts or AFloats based on the element type T. + template + using ElementVectorFor = std::conditional_t>, AFloats, AInts>; /// Constructs an invalid Constant Constant(); @@ -39,7 +58,23 @@ class Constant { /// Constructs a Constant of the given type and element values /// @param ty the Constant type /// @param els the Constant element values - Constant(const Type* ty, Scalars els); + Constant(const sem::Type* ty, Elements els); + + /// Constructs a Constant of the given type and element values + /// @param ty the Constant type + /// @param vec the Constant element values + Constant(const sem::Type* ty, AInts vec); + + /// Constructs a Constant of the given type and element values + /// @param ty the Constant type + /// @param vec the Constant element values + Constant(const sem::Type* ty, AFloats vec); + + /// Constructs a Constant of the given type and element values + /// @param ty the Constant type + /// @param els the Constant element values + template + Constant(const sem::Type* ty, std::initializer_list els); /// Copy constructor Constant(const Constant&); @@ -61,42 +96,77 @@ class Constant { /// @returns the type of the Constant const sem::Type* Type() const { return type_; } + /// @returns the number of elements + size_t ElementCount() const { + return std::visit([](auto&& v) { return v.size(); }, elems_); + } + /// @returns the element type of the Constant const sem::Type* ElementType() const { return elem_type_; } - /// @returns the constant's scalar elements - const Scalars& Elements() const { return elems_; } + /// @returns the constant's elements + const Elements& GetElements() const { return elems_; } - /// @returns true if any scalar element is zero + /// WithElements calls the function `f` with the vector of elements as either AFloats or AInts + /// @param f a function-like with the signature `R(auto&&)`. + /// @returns the result of calling `f`. + template + auto WithElements(F&& f) const { + return std::visit(std::forward(f), elems_); + } + + /// WithElements calls the function `f` with the element vector as either AFloats or AInts + /// @param f a function-like with the signature `R(auto&&)`. + /// @returns the result of calling `f`. + template + auto WithElements(F&& f) { + return std::visit(std::forward(f), elems_); + } + + /// @returns the elements as a vector of AInt + inline const AInts& IElements() const { return std::get(elems_); } + + /// @returns the elements as a vector of AFloat + inline const AFloats& FElements() const { return std::get(elems_); } + + /// @returns true if any element is zero bool AnyZero() const; - /// @param index the index of the scalar value - /// @return the value of the scalar at `index`, which must be of type `T`. + /// @param index the index of the element + /// @return the element at `index`, which must be of type `T`. template - T Element(size_t index) const { - return std::get(elems_[index]); - } - - /// @param index the index of the scalar value - /// @return the value of the scalar `static_cast` to type T. - template - T ElementAs(size_t index) const { - return Cast(elems_[index]); - } - - /// @param s the input scalar - /// @returns the scalar `s` cast to the type `T`. - template - static T Cast(Scalar s) { - return std::visit([](auto v) { return static_cast(v); }, s); - } + T Element(size_t index) const; private: + /// Checks that the provided type matches the number of expected elements. + /// @returns the element type of `ty`. + const sem::Type* CheckElemType(const sem::Type* ty, size_t num_elements); + const sem::Type* type_ = nullptr; const sem::Type* elem_type_ = nullptr; - Scalars elems_; + Elements elems_; }; +template +Constant::Constant(const sem::Type* ty, std::initializer_list els) + : type_(ty), elem_type_(CheckElemType(type_, els.size())) { + ElementVectorFor elements; + elements.reserve(els.size()); + for (auto el : els) { + elements.emplace_back(AFloat(el)); + } + elems_ = Elements{std::move(elements)}; +} + +template +T Constant::Element(size_t index) const { + if constexpr (std::is_same_v, AFloats>) { + return static_cast(FElements()[index].value); + } else { + return static_cast(IElements()[index].value); + } +} + } // namespace tint::sem #endif // SRC_TINT_SEM_CONSTANT_H_ diff --git a/src/tint/sem/constant_test.cc b/src/tint/sem/constant_test.cc new file mode 100644 index 0000000000..345ebd807b --- /dev/null +++ b/src/tint/sem/constant_test.cc @@ -0,0 +1,199 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/sem/constant.h" + +#include + +#include "src/tint/sem/abstract_float.h" +#include "src/tint/sem/abstract_int.h" +#include "src/tint/sem/test_helper.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::sem { +namespace { + +using ConstantTest = TestHelper; + +TEST_F(ConstantTest, ConstructorInitializerList) { + { + Constant c(create(), {1_a}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); }); + } + { + Constant c(create(), {1_i}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); }); + } + { + Constant c(create(), {1_u}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); }); + } + { + Constant c(create(), {false}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(0_a)); }); + } + { + Constant c(create(), {true}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1_a)); }); + } + { + Constant c(create(), {1.0_a}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); }); + } + { + Constant c(create(), {1.0_f}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); }); + } + { + Constant c(create(), {1.0_h}); + c.WithElements([&](auto&& vec) { EXPECT_THAT(vec, testing::ElementsAre(1.0_a)); }); + } +} + +TEST_F(ConstantTest, Element_ai) { + Constant c(create(), {1_a}); + EXPECT_EQ(c.Element(0), 1_a); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_i32) { + Constant c(create(), {1_a}); + EXPECT_EQ(c.Element(0), 1_i); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_u32) { + Constant c(create(), {1_a}); + EXPECT_EQ(c.Element(0), 1_u); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_bool) { + Constant c(create(), {true}); + EXPECT_EQ(c.Element(0), true); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_af) { + Constant c(create(), {1.0_a}); + EXPECT_EQ(c.Element(0), 1.0_a); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_f32) { + Constant c(create(), {1.0_a}); + EXPECT_EQ(c.Element(0), 1.0_f); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_f16) { + Constant c(create(), {1.0_a}); + EXPECT_EQ(c.Element(0), 1.0_h); + EXPECT_EQ(c.ElementCount(), 1u); +} + +TEST_F(ConstantTest, Element_vec3_ai) { + Constant c(create(create(), 3u), {1_a, 2_a, 3_a}); + EXPECT_EQ(c.Element(0), 1_a); + EXPECT_EQ(c.Element(1), 2_a); + EXPECT_EQ(c.Element(2), 3_a); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_vec3_i32) { + Constant c(create(create(), 3u), {1_a, 2_a, 3_a}); + EXPECT_EQ(c.Element(0), 1_i); + EXPECT_EQ(c.Element(1), 2_i); + EXPECT_EQ(c.Element(2), 3_i); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_vec3_u32) { + Constant c(create(create(), 3u), {1_a, 2_a, 3_a}); + EXPECT_EQ(c.Element(0), 1_u); + EXPECT_EQ(c.Element(1), 2_u); + EXPECT_EQ(c.Element(2), 3_u); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_vec3_bool) { + Constant c(create(create(), 2u), {true, false}); + EXPECT_EQ(c.Element(0), true); + EXPECT_EQ(c.Element(1), false); + EXPECT_EQ(c.ElementCount(), 2u); +} + +TEST_F(ConstantTest, Element_vec3_af) { + Constant c(create(create(), 3u), {1.0_a, 2.0_a, 3.0_a}); + EXPECT_EQ(c.Element(0), 1.0_a); + EXPECT_EQ(c.Element(1), 2.0_a); + EXPECT_EQ(c.Element(2), 3.0_a); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_vec3_f32) { + Constant c(create(create(), 3u), {1.0_a, 2.0_a, 3.0_a}); + EXPECT_EQ(c.Element(0), 1.0_f); + EXPECT_EQ(c.Element(1), 2.0_f); + EXPECT_EQ(c.Element(2), 3.0_f); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_vec3_f16) { + Constant c(create(create(), 3u), {1.0_a, 2.0_a, 3.0_a}); + EXPECT_EQ(c.Element(0), 1.0_h); + EXPECT_EQ(c.Element(1), 2.0_h); + EXPECT_EQ(c.Element(2), 3.0_h); + EXPECT_EQ(c.ElementCount(), 3u); +} + +TEST_F(ConstantTest, Element_mat2x3_af) { + Constant c(create(create(create(), 3u), 2u), + {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a}); + EXPECT_EQ(c.Element(0), 1.0_a); + EXPECT_EQ(c.Element(1), 2.0_a); + EXPECT_EQ(c.Element(2), 3.0_a); + EXPECT_EQ(c.Element(3), 4.0_a); + EXPECT_EQ(c.Element(4), 5.0_a); + EXPECT_EQ(c.Element(5), 6.0_a); + EXPECT_EQ(c.ElementCount(), 6u); +} + +TEST_F(ConstantTest, Element_mat2x3_f32) { + Constant c(create(create(create(), 3u), 2u), + {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a}); + EXPECT_EQ(c.Element(0), 1.0_f); + EXPECT_EQ(c.Element(1), 2.0_f); + EXPECT_EQ(c.Element(2), 3.0_f); + EXPECT_EQ(c.Element(3), 4.0_f); + EXPECT_EQ(c.Element(4), 5.0_f); + EXPECT_EQ(c.Element(5), 6.0_f); + EXPECT_EQ(c.ElementCount(), 6u); +} + +TEST_F(ConstantTest, Element_mat2x3_f16) { + Constant c(create(create(create(), 3u), 2u), + {1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a}); + EXPECT_EQ(c.Element(0), 1.0_h); + EXPECT_EQ(c.Element(1), 2.0_h); + EXPECT_EQ(c.Element(2), 3.0_h); + EXPECT_EQ(c.Element(3), 4.0_h); + EXPECT_EQ(c.Element(4), 5.0_h); + EXPECT_EQ(c.Element(5), 6.0_h); + EXPECT_EQ(c.ElementCount(), 6u); +} + +} // namespace +} // namespace tint::sem diff --git a/src/tint/transform/fold_constants.cc b/src/tint/transform/fold_constants.cc index d51a68a725..f268800b35 100644 --- a/src/tint/transform/fold_constants.cc +++ b/src/tint/transform/fold_constants.cc @@ -23,6 +23,7 @@ #include "src/tint/sem/expression.h" #include "src/tint/sem/type_constructor.h" #include "src/tint/sem/type_conversion.h" +#include "src/tint/utils/transform.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants); @@ -50,24 +51,40 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { return nullptr; } - // If original ctor expression had no init values, don't replace the - // expression + // If original ctor expression had no init values, don't replace the expression if (call->Arguments().empty()) { return nullptr; } - auto build_scalar = [&](sem::Constant::Scalar s) { + auto build_elements = [&](size_t limit) { return Switch( value.ElementType(), // - [&](const sem::I32*) { return ctx.dst->Expr(i32(std::get(s).value)); }, - [&](const sem::U32*) { return ctx.dst->Expr(u32(std::get(s).value)); }, - [&](const sem::F32*) { return ctx.dst->Expr(f32(std::get(s).value)); }, - [&](const sem::Bool*) { return ctx.dst->Expr(std::get(s)); }, + [&](const sem::Bool*) { + return utils::TransformN(value.IElements(), limit, [&](AInt i) { + return static_cast( + ctx.dst->Expr(static_cast(i.value))); + }); + }, + [&](const sem::I32*) { + return utils::TransformN(value.IElements(), limit, [&](AInt i) { + return static_cast(ctx.dst->Expr(i32(i.value))); + }); + }, + [&](const sem::U32*) { + return utils::TransformN(value.IElements(), limit, [&](AInt i) { + return static_cast(ctx.dst->Expr(u32(i.value))); + }); + }, + [&](const sem::F32*) { + return utils::TransformN(value.FElements(), limit, [&](AFloat f) { + return static_cast(ctx.dst->Expr(f32(f.value))); + }); + }, [&](Default) { TINT_ICE(Transform, ctx.dst->Diagnostics()) << "unhandled Constant::Scalar type: " << value.ElementType()->FriendlyName(ctx.src->Symbols()); - return nullptr; + return ast::ExpressionList{}; }); }; @@ -78,17 +95,17 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { // constructor args that the original node had, but after folding // constants, cases like the following are problematic: // - // vec3 = vec3(vec2, 1.0) // vec_size=3, ctor_size=2 + // vec3 = vec3(vec2(), 1.0) // vec_size=3, ctor_size=2 // // In this case, creating a vec3 with 2 args is invalid, so we should // create it with 3. So what we do is construct with vec_size args, // except if the original vector was single-value initialized, in // which case, we only construct with one arg again. - uint32_t ctor_size = (call->Arguments().size() == 1) ? 1 : vec_size; - ast::ExpressionList ctors; - for (uint32_t i = 0; i < ctor_size; ++i) { - ctors.emplace_back(build_scalar(value.Elements()[i])); + if (call->Arguments().size() == 1) { + ctors = build_elements(1); + } else { + ctors = build_elements(value.ElementCount()); } auto* el_ty = CreateASTTypeFor(ctx, vec->type()); @@ -96,7 +113,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { } if (ty->is_scalar()) { - return build_scalar(value.Elements()[0]); + return build_elements(1)[0]; } return nullptr; diff --git a/src/tint/utils/compiler_macros.h b/src/tint/utils/compiler_macros.h new file mode 100644 index 0000000000..8b360ffe4b --- /dev/null +++ b/src/tint/utils/compiler_macros.h @@ -0,0 +1,39 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_UTILS_COMPILER_MACROS_H_ +#define SRC_TINT_UTILS_COMPILER_MACROS_H_ + +#define TINT_REQUIRE_SEMICOLON \ + do { \ + } while (false) + +#if defined(_MSC_VER) +// clang-format off +#define TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE() \ + __pragma(warning(push)) \ + __pragma(warning(disable:4702)) \ + TINT_REQUIRE_SEMICOLON +#define TINT_END_DISABLE_WARNING_UNREACHABLE_CODE() \ + __pragma(warning(pop)) \ + TINT_REQUIRE_SEMICOLON +// clang-format on +#else +// clang-format off +#define TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE() TINT_REQUIRE_SEMICOLON +#define TINT_END_DISABLE_WARNING_UNREACHABLE_CODE() TINT_REQUIRE_SEMICOLON +// clang-format on +#endif // defined(_MSC_VER) + +#endif // SRC_TINT_UTILS_COMPILER_MACROS_H_ diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index affaf5e2ec..344d0d87c8 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -639,7 +639,6 @@ bool GeneratorImpl::EmitAssign(const ast::AssignmentStatement* stmt) { bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expression* expr) { // For constants, replace literal 0 with 1. - sem::Constant::Scalars elems; if (const auto& val = builder_.Sem().Get(expr)->ConstantValue()) { if (!val.AnyZero()) { return EmitExpression(out, expr); @@ -657,7 +656,7 @@ bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expr } out << "("; - for (size_t i = 0; i < val.Elements().size(); ++i) { + for (size_t i = 0; i < val.ElementCount(); ++i) { if (i != 0) { out << ", "; } diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index 5f8072bb05..df57345b8f 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -924,7 +924,7 @@ bool Builder::GenerateIndexAccessor(const ast::IndexAccessorExpression* expr, Ac Operand(result_type_id), extract, Operand(info->source_id), - Operand(idx_constval.ElementAs(0)), + Operand(idx_constval.Element(0)), })) { return false; } diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn index 5ff7829ffe..7acdea9ce8 100644 --- a/test/tint/BUILD.gn +++ b/test/tint/BUILD.gn @@ -289,6 +289,7 @@ tint_unittests_source_set("tint_unittests_sem_src") { "../../src/tint/sem/atomic_test.cc", "../../src/tint/sem/bool_test.cc", "../../src/tint/sem/builtin_test.cc", + "../../src/tint/sem/constant_test.cc", "../../src/tint/sem/depth_multisampled_texture_test.cc", "../../src/tint/sem/depth_texture_test.cc", "../../src/tint/sem/expression_test.cc",