diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index d647e96ece..4993cfef84 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -793,13 +793,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { continue; } // validator_.Validate and set the default value for this dimension. - if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) { + if (value.Element(0).value < 1) { AddError("workgroup_size argument must be at least 1", values[i]->source); return false; } - ws[i].value = - is_i32 ? static_cast(value.Elements()[0].i32) : value.Elements()[0].u32; + ws[i].value = static_cast(value.Element(0).value); } current_function_->SetWorkgroupSize(std::move(ws)); @@ -1855,13 +1854,12 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } - if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1 - : count_val.Elements()[0].u32 < 1u) { + if (count_val.Element(0).value < 1) { AddError("array size must be at least 1", size_source); return nullptr; } - count = count_val.Elements()[0].u32; + count = static_cast(count_val.Element(0).value); } auto size = std::max(count, 1) * stride; diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc index 3ec6f94a65..66094cf65d 100644 --- a/src/tint/resolver/resolver_constants.cc +++ b/src/tint/resolver/resolver_constants.cc @@ -37,13 +37,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite return Switch( literal, [&](const ast::IntLiteralExpression* lit) { - if (lit->suffix == ast::IntLiteralExpression::Suffix::kU) { - return sem::Constant{type, {u32(lit->value)}}; - } - return sem::Constant{type, {i32(lit->value)}}; + return sem::Constant{type, {AInt(lit->value)}}; }, [&](const ast::FloatLiteralExpression* lit) { - return sem::Constant{type, {f32(lit->value)}}; + return sem::Constant{type, {AFloat(lit->value)}}; }, [&](const ast::BoolLiteralExpression* lit) { return sem::Constant{type, {lit->value}}; @@ -64,21 +61,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, // For zero value init, return 0s if (call->args.empty()) { - if (elem_type->Is()) { - return sem::Constant(type, sem::Constant::Scalars(result_size, 0_i)); - } - if (elem_type->Is()) { - return sem::Constant(type, sem::Constant::Scalars(result_size, 0_u)); - } - // Add f16 zero scalar here - if (elem_type->Is()) { - return sem::Constant(type, sem::Constant::Scalars(result_size, f16{0.f})); - } - if (elem_type->Is()) { - return sem::Constant(type, sem::Constant::Scalars(result_size, 0_f)); - } - if (elem_type->Is()) { - return sem::Constant(type, sem::Constant::Scalars(result_size, false)); + using Scalars = sem::Constant::Scalars; + auto constant = Switch( + elem_type, + [&](const sem::I32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); }, + [&](const sem::U32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); }, + [&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, + [&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); }, + [&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); }); + if (constant.IsValid()) { + return constant; } } @@ -112,33 +104,14 @@ sem::Constant Resolver::ConstantCast(const sem::Constant& value, sem::Constant::Scalars elems; for (size_t i = 0; i < value.Elements().size(); ++i) { + // TODO(crbug.com/tint/1504): Check that value fits in new type elems.emplace_back(Switch( - target_elem_type, - [&](const sem::I32*) { - return value.WithScalarAt(i, [](auto&& s) { // - return i32(static_cast(s)); - }); - }, - [&](const sem::U32*) { - return value.WithScalarAt(i, [](auto&& s) { // - return u32(static_cast(s)); - }); - }, - [&](const sem::F16*) { - return value.WithScalarAt(i, [](auto&& s) { // - return f16{static_cast(s)}; - }); - }, - [&](const sem::F32*) { - return value.WithScalarAt(i, [](auto&& s) { // - return static_cast(s); - }); - }, - [&](const sem::Bool*) { - return value.WithScalarAt(i, [](auto&& s) { // - return static_cast(s); - }); - }, + target_elem_type, // + [&](const sem::I32*) { return value.ElementAs(i); }, + [&](const sem::U32*) { return value.ElementAs(i); }, + [&](const sem::F32*) { return value.ElementAs(i); }, + [&](const sem::F16*) { return value.ElementAs(i); }, + [&](const sem::Bool*) { return value.ElementAs(i); }, [&](Default) { diag::List diags; TINT_UNREACHABLE(Semantic, diags) diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc index a9624cab6a..05e6e4c4dc 100644 --- a/src/tint/resolver/resolver_constants_test.cc +++ b/src/tint/resolver/resolver_constants_test.cc @@ -39,7 +39,7 @@ TEST_F(ResolverConstantsTest, Scalar_i32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 99); } TEST_F(ResolverConstantsTest, Scalar_u32) { @@ -54,7 +54,7 @@ TEST_F(ResolverConstantsTest, Scalar_u32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 99u); } TEST_F(ResolverConstantsTest, Scalar_f32) { @@ -69,7 +69,7 @@ TEST_F(ResolverConstantsTest, Scalar_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 9.9f); } TEST_F(ResolverConstantsTest, Scalar_bool) { @@ -84,7 +84,7 @@ TEST_F(ResolverConstantsTest, Scalar_bool) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); + EXPECT_EQ(sem->ConstantValue().Element(0), true); } TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { @@ -101,9 +101,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 0); - EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 0); - EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 0); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 0); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 0); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 0); } TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { @@ -120,9 +120,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 0u); - EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 0u); - EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 0u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 0u); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 0u); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 0u); } TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { @@ -139,9 +139,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 0.0f); - EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 0.0f); - EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 0.0f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 0.0); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 0.0); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 0.0); } TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { @@ -158,9 +158,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, false); - EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); - EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, false); + EXPECT_EQ(sem->ConstantValue().Element(0), false); + EXPECT_EQ(sem->ConstantValue().Element(1), false); + EXPECT_EQ(sem->ConstantValue().Element(2), false); } TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { @@ -177,9 +177,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99); - EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 99); - EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 99); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 99); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 99); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 99); } TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { @@ -196,9 +196,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u); - EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 99u); - EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 99u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 99u); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 99u); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 99u); } TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { @@ -215,9 +215,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f); - EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 9.9f); - EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 9.9f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 9.9f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 9.9f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 9.9f); } TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { @@ -234,9 +234,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); - EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, true); - EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); + EXPECT_EQ(sem->ConstantValue().Element(0), true); + EXPECT_EQ(sem->ConstantValue().Element(1), true); + EXPECT_EQ(sem->ConstantValue().Element(2), true); } TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { @@ -253,9 +253,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); - EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); - EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); } TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { @@ -272,9 +272,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u); - EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); } TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { @@ -291,9 +291,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f); - EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f); - EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1.f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2.f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3.f); } TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { @@ -310,9 +310,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); - EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); - EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); + EXPECT_EQ(sem->ConstantValue().Element(0), true); + EXPECT_EQ(sem->ConstantValue().Element(1), false); + EXPECT_EQ(sem->ConstantValue().Element(2), true); } TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { @@ -329,9 +329,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); - EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); - EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); } TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { @@ -348,9 +348,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u); - EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u); - EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); } TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { @@ -367,9 +367,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f); - EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f); - EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1.f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2.f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3.f); } TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { @@ -386,9 +386,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); - EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); - EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); + EXPECT_EQ(sem->ConstantValue().Element(0), true); + EXPECT_EQ(sem->ConstantValue().Element(1), false); + EXPECT_EQ(sem->ConstantValue().Element(2), true); } TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) { @@ -405,9 +405,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); - EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); - EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3); } TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) { @@ -424,9 +424,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) { EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); - EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 10.f); - EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 20.f); - EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 30.f); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 10.f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 20.f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 30.f); } } // namespace diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 73e1acfd3a..ec0b16f65b 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1515,7 +1515,7 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const { if (is_const_expr) { auto vector = builtin->Parameters()[index]->Type()->Is(); for (size_t i = 0; i < values.Elements().size(); i++) { - auto value = values.Elements()[i].i32; + auto value = values.Element(i).value; if (value < min || value > max) { if (vector) { AddError("each component of the " + name + diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index f5c82a2ced..1c4dc58dd7 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -60,16 +60,12 @@ Constant::~Constant() = default; Constant& Constant::operator=(const Constant& rhs) = default; bool Constant::AnyZero() const { - for (size_t i = 0; i < Elements().size(); ++i) { - if (WithScalarAt(i, [&](auto&& s) { - // Use std::equal_to to work around -Wfloat-equal warnings - using T = std::remove_reference_t; - auto equal_to = std::equal_to{}; - if (equal_to(s, T(0))) { - return true; - } - return false; - })) { + for (auto scalar : elems_) { + auto is_zero = [&](auto&& s) { + using T = std::remove_reference_t; + return s == T(0); + }; + if (std::visit(is_zero, scalar)) { return true; } } diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index 2e8258dd1c..a20170c97a 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -15,6 +15,7 @@ #ifndef SRC_TINT_SEM_CONSTANT_H_ #define SRC_TINT_SEM_CONSTANT_H_ +#include #include #include "src/tint/program_builder.h" @@ -26,40 +27,8 @@ namespace tint::sem { /// list of scalar values. Value may be of a scalar or vector type. class Constant { public: - /// Scalar holds a single constant scalar value, as a union of an i32, u32, - /// f32 or boolean. - union Scalar { - /// The scalar value as a i32 - tint::i32 i32; - /// The scalar value as a u32 - tint::u32 u32; - /// The scalar value as a f32 - tint::f32 f32; - /// The scalar value as a f16, internally stored as float - tint::f16 f16; - /// The scalar value as a bool - bool bool_; - - /// Constructs the scalar with the i32 value `v` - /// @param v the value of the Scalar - Scalar(tint::i32 v) : i32(v) {} // NOLINT - - /// Constructs the scalar with the u32 value `v` - /// @param v the value of the Scalar - Scalar(tint::u32 v) : u32(v) {} // NOLINT - - /// Constructs the scalar with the f32 value `v` - /// @param v the value of the Scalar - Scalar(tint::f32 v) : f32(v) {} // NOLINT - - /// Constructs the scalar with the f16 value `v` - /// @param v the value of the Scalar - Scalar(tint::f16 v) : f16({v}) {} // NOLINT - - /// Constructs the scalar with the bool value `v` - /// @param v the value of the Scalar - Scalar(bool v) : bool_(v) {} // NOLINT - }; + /// Scalar holds a single constant scalar value - one of: AInt, AFloat or bool. + using Scalar = std::variant; /// Scalars is a list of scalar values using Scalars = std::vector; @@ -101,33 +70,18 @@ class Constant { /// @returns true if any scalar element is zero bool AnyZero() const; - /// Calls `func(s)` with s being the current scalar value at `index`. - /// `func` is typically a lambda of the form '[](auto&& s)'. /// @param index the index of the scalar value - /// @param func a function with signature `T(S)` - /// @return the value returned by func. - template - auto WithScalarAt(size_t index, Func&& func) const { - return Switch( - ElementType(), // - [&](const I32*) { return func(elems_[index].i32); }, - [&](const U32*) { return func(elems_[index].u32); }, - [&](const F16*) { return func(elems_[index].f16); }, - [&](const F32*) { return func(elems_[index].f32); }, - [&](const Bool*) { return func(elems_[index].bool_); }, - [&](Default) { - diag::List diags; - TINT_UNREACHABLE(Semantic, diags) - << "invalid scalar type " << type_->TypeInfo().name; - return func(u32(0u)); - }); + /// @return the value of the scalar at `index`, which must be of type `T`. + template + T Element(size_t index) const { + return std::get(elems_[index]); } /// @param index the index of the scalar value /// @return the value of the scalar `static_cast` to type T. template T ElementAs(size_t index) const { - return WithScalarAt(index, [](auto val) { return static_cast(val); }); + return std::visit([](auto val) { return static_cast(val); }, elems_[index]); } private: diff --git a/src/tint/transform/fold_constants.cc b/src/tint/transform/fold_constants.cc index be547b144c..d51a68a725 100644 --- a/src/tint/transform/fold_constants.cc +++ b/src/tint/transform/fold_constants.cc @@ -56,6 +56,21 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { return nullptr; } + auto build_scalar = [&](sem::Constant::Scalar s) { + return Switch( + value.ElementType(), // + [&](const sem::I32*) { return ctx.dst->Expr(i32(std::get(s).value)); }, + [&](const sem::U32*) { return ctx.dst->Expr(u32(std::get(s).value)); }, + [&](const sem::F32*) { return ctx.dst->Expr(f32(std::get(s).value)); }, + [&](const sem::Bool*) { return ctx.dst->Expr(std::get(s)); }, + [&](Default) { + TINT_ICE(Transform, ctx.dst->Diagnostics()) + << "unhandled Constant::Scalar type: " + << value.ElementType()->FriendlyName(ctx.src->Symbols()); + return nullptr; + }); + }; + if (auto* vec = ty->As()) { uint32_t vec_size = static_cast(vec->Width()); @@ -73,7 +88,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { ast::ExpressionList ctors; for (uint32_t i = 0; i < ctor_size; ++i) { - value.WithScalarAt(i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); }); + ctors.emplace_back(build_scalar(value.Elements()[i])); } auto* el_ty = CreateASTTypeFor(ctx, vec->type()); @@ -81,8 +96,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const { } if (ty->is_scalar()) { - return value.WithScalarAt( - 0, [&](auto&& s) -> const ast::LiteralExpression* { return ctx.dst->Expr(s); }); + return build_scalar(value.Elements()[0]); } return nullptr; diff --git a/src/tint/transform/robustness.cc b/src/tint/transform/robustness.cc index a7f2a4abb5..46624b609a 100644 --- a/src/tint/transform/robustness.cc +++ b/src/tint/transform/robustness.cc @@ -123,10 +123,10 @@ struct Robustness::State { if (auto idx_constant = idx_sem->ConstantValue()) { // Constant value index if (idx_constant.Type()->Is()) { - idx.i32 = idx_constant.Elements()[0].i32; + idx.i32 = static_cast(idx_constant.Element(0).value); idx.is_signed = true; } else if (idx_constant.Type()->Is()) { - idx.u32 = idx_constant.Elements()[0].u32; + idx.u32 = static_cast(idx_constant.Element(0).value); idx.is_signed = false; } else { TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor " diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc index 21a4565142..6e84310992 100644 --- a/src/tint/transform/zero_init_workgroup_memory.cc +++ b/src/tint/transform/zero_init_workgroup_memory.cc @@ -359,13 +359,8 @@ struct ZeroInitWorkgroupMemory::State { } auto* sem = ctx.src->Sem().Get(expr); if (auto c = sem->ConstantValue()) { - if (c.ElementType()->Is()) { - workgroup_size_const *= static_cast(c.Elements()[0].i32); - continue; - } else if (c.ElementType()->Is()) { - workgroup_size_const *= c.Elements()[0].u32; - continue; - } + workgroup_size_const *= c.Element(0).value; + continue; } // Constant value could not be found. Build expression instead. workgroup_size_expr = [this, expr, size = workgroup_size_expr] { diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 9a66118514..45d57a29b3 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -660,13 +660,8 @@ bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expr if (i != 0) { out << ", "; } - if (!val.WithScalarAt(i, [&](auto&& s) -> bool { - // Use std::equal_to to work around -Wfloat-equal warnings - using T = std::remove_reference_t; - auto equal_to = std::equal_to{}; - bool is_zero = equal_to(s, T(0)); - return EmitValue(out, elem_ty, is_zero ? 1 : static_cast(s)); - })) { + auto s = val.Element(i).value; + if (!EmitValue(out, elem_ty, (s == 0) ? 1 : static_cast(s))) { return false; } } @@ -1191,7 +1186,7 @@ bool GeneratorImpl::EmitUniformBufferAccess( if (auto val = offset_arg->ConstantValue()) { TINT_ASSERT(Writer, val.Type()->Is()); - scalar_offset_value = val.Elements()[0].u32; + scalar_offset_value = static_cast(val.Element(0).value); scalar_offset_value /= 4; // bytes -> scalar index scalar_offset_constant = true; } @@ -2371,7 +2366,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out, case sem::BuiltinType::kTextureGather: out << ".Gather"; if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) { - switch (call->Arguments()[0]->ConstantValue().Elements()[0].i32) { + switch (call->Arguments()[0]->ConstantValue().Element(0).value) { case 0: out << "Red"; break; diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 743e03b58a..ce764fb865 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -1136,8 +1136,8 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out, break; // Other texture dimensions don't have an offset } } - auto c = component->ConstantValue().Elements()[0].i32; - switch (c) { + auto c = component->ConstantValue().Element(0); + switch (c.value) { case 0: out << "component::x"; break;