tint: Simplify sem::Constant::Scalar

Migrate from a hand-rolled tagged-union of [i32, u32, f32, f16, bool]
types. Instead use a std::variant of [AInt, AFloat, bool]. The Constant
holds the actual type, so no information is lost with the reduced types.

Note: Currently integer constants are still limited to 32-bits in size.
This is enforced by the frontend.

Bug: tint:1504
Change-Id: I316957787649c454fffb532334159d726cd1fb2d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90643
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-17 20:51:04 +00:00 committed by Dawn LUCI CQ
parent c590373cfb
commit aaa9ba3043
11 changed files with 123 additions and 198 deletions

View File

@ -793,13 +793,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
continue;
}
// validator_.Validate and set the default value for this dimension.
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) {
if (value.Element<AInt>(0).value < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source);
return false;
}
ws[i].value =
is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32) : value.Elements()[0].u32;
ws[i].value = static_cast<uint32_t>(value.Element<AInt>(0).value);
}
current_function_->SetWorkgroupSize(std::move(ws));
@ -1855,13 +1854,12 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr;
}
if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1
: count_val.Elements()[0].u32 < 1u) {
if (count_val.Element<AInt>(0).value < 1) {
AddError("array size must be at least 1", size_source);
return nullptr;
}
count = count_val.Elements()[0].u32;
count = static_cast<uint32_t>(count_val.Element<AInt>(0).value);
}
auto size = std::max<uint64_t>(count, 1) * stride;

View File

@ -37,13 +37,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
return Switch(
literal,
[&](const ast::IntLiteralExpression* lit) {
if (lit->suffix == ast::IntLiteralExpression::Suffix::kU) {
return sem::Constant{type, {u32(lit->value)}};
}
return sem::Constant{type, {i32(lit->value)}};
return sem::Constant{type, {AInt(lit->value)}};
},
[&](const ast::FloatLiteralExpression* lit) {
return sem::Constant{type, {f32(lit->value)}};
return sem::Constant{type, {AFloat(lit->value)}};
},
[&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {lit->value}};
@ -64,21 +61,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
// For zero value init, return 0s
if (call->args.empty()) {
if (elem_type->Is<sem::I32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_i));
}
if (elem_type->Is<sem::U32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_u));
}
// Add f16 zero scalar here
if (elem_type->Is<sem::F16>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, f16{0.f}));
}
if (elem_type->Is<sem::F32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_f));
}
if (elem_type->Is<sem::Bool>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, false));
using Scalars = sem::Constant::Scalars;
auto constant = Switch(
elem_type,
[&](const sem::I32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
[&](const sem::U32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
[&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
if (constant.IsValid()) {
return constant;
}
}
@ -112,33 +104,14 @@ sem::Constant Resolver::ConstantCast(const sem::Constant& value,
sem::Constant::Scalars elems;
for (size_t i = 0; i < value.Elements().size(); ++i) {
// TODO(crbug.com/tint/1504): Check that value fits in new type
elems.emplace_back(Switch<sem::Constant::Scalar>(
target_elem_type,
[&](const sem::I32*) {
return value.WithScalarAt(i, [](auto&& s) { //
return i32(static_cast<int32_t>(s));
});
},
[&](const sem::U32*) {
return value.WithScalarAt(i, [](auto&& s) { //
return u32(static_cast<uint32_t>(s));
});
},
[&](const sem::F16*) {
return value.WithScalarAt(i, [](auto&& s) { //
return f16{static_cast<float>(s)};
});
},
[&](const sem::F32*) {
return value.WithScalarAt(i, [](auto&& s) { //
return static_cast<f32>(s);
});
},
[&](const sem::Bool*) {
return value.WithScalarAt(i, [](auto&& s) { //
return static_cast<bool>(s);
});
},
target_elem_type, //
[&](const sem::I32*) { return value.ElementAs<AInt>(i); },
[&](const sem::U32*) { return value.ElementAs<AInt>(i); },
[&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
[&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
[&](const sem::Bool*) { return value.ElementAs<bool>(i); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)

View File

@ -39,7 +39,7 @@ TEST_F(ResolverConstantsTest, Scalar_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
}
TEST_F(ResolverConstantsTest, Scalar_u32) {
@ -54,7 +54,7 @@ TEST_F(ResolverConstantsTest, Scalar_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
}
TEST_F(ResolverConstantsTest, Scalar_f32) {
@ -69,7 +69,7 @@ TEST_F(ResolverConstantsTest, Scalar_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
}
TEST_F(ResolverConstantsTest, Scalar_bool) {
@ -84,7 +84,7 @@ TEST_F(ResolverConstantsTest, Scalar_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
}
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
@ -101,9 +101,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 0);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 0);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 0);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0);
}
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
@ -120,9 +120,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 0u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 0u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 0u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0u);
}
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
@ -139,9 +139,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 0.0f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 0.0f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 0.0f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
}
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
@ -158,9 +158,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, false);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), false);
}
TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
@ -177,9 +177,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 99);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99);
}
TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
@ -196,9 +196,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 99u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99u);
}
TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
@ -215,9 +215,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 9.9f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 9.9f);
}
TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
@ -234,9 +234,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, true);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
}
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
@ -253,9 +253,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
@ -272,9 +272,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
@ -291,9 +291,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
}
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
@ -310,9 +310,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
}
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
@ -329,9 +329,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
@ -348,9 +348,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
@ -367,9 +367,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
}
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
@ -386,9 +386,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
}
TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
@ -405,9 +405,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
}
TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
@ -424,9 +424,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 10.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 20.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 30.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 10.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 20.f);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 30.f);
}
} // namespace

View File

@ -1515,7 +1515,7 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
if (is_const_expr) {
auto vector = builtin->Parameters()[index]->Type()->Is<sem::Vector>();
for (size_t i = 0; i < values.Elements().size(); i++) {
auto value = values.Elements()[i].i32;
auto value = values.Element<AInt>(i).value;
if (value < min || value > max) {
if (vector) {
AddError("each component of the " + name +

View File

@ -60,16 +60,12 @@ Constant::~Constant() = default;
Constant& Constant::operator=(const Constant& rhs) = default;
bool Constant::AnyZero() const {
for (size_t i = 0; i < Elements().size(); ++i) {
if (WithScalarAt(i, [&](auto&& s) {
// Use std::equal_to to work around -Wfloat-equal warnings
using T = std::remove_reference_t<decltype(s)>;
auto equal_to = std::equal_to<T>{};
if (equal_to(s, T(0))) {
return true;
}
return false;
})) {
for (auto scalar : elems_) {
auto is_zero = [&](auto&& s) {
using T = std::remove_reference_t<decltype(s)>;
return s == T(0);
};
if (std::visit(is_zero, scalar)) {
return true;
}
}

View File

@ -15,6 +15,7 @@
#ifndef SRC_TINT_SEM_CONSTANT_H_
#define SRC_TINT_SEM_CONSTANT_H_
#include <variant>
#include <vector>
#include "src/tint/program_builder.h"
@ -26,40 +27,8 @@ namespace tint::sem {
/// list of scalar values. Value may be of a scalar or vector type.
class Constant {
public:
/// Scalar holds a single constant scalar value, as a union of an i32, u32,
/// f32 or boolean.
union Scalar {
/// The scalar value as a i32
tint::i32 i32;
/// The scalar value as a u32
tint::u32 u32;
/// The scalar value as a f32
tint::f32 f32;
/// The scalar value as a f16, internally stored as float
tint::f16 f16;
/// The scalar value as a bool
bool bool_;
/// Constructs the scalar with the i32 value `v`
/// @param v the value of the Scalar
Scalar(tint::i32 v) : i32(v) {} // NOLINT
/// Constructs the scalar with the u32 value `v`
/// @param v the value of the Scalar
Scalar(tint::u32 v) : u32(v) {} // NOLINT
/// Constructs the scalar with the f32 value `v`
/// @param v the value of the Scalar
Scalar(tint::f32 v) : f32(v) {} // NOLINT
/// Constructs the scalar with the f16 value `v`
/// @param v the value of the Scalar
Scalar(tint::f16 v) : f16({v}) {} // NOLINT
/// Constructs the scalar with the bool value `v`
/// @param v the value of the Scalar
Scalar(bool v) : bool_(v) {} // NOLINT
};
/// Scalar holds a single constant scalar value - one of: AInt, AFloat or bool.
using Scalar = std::variant<AInt, AFloat, bool>;
/// Scalars is a list of scalar values
using Scalars = std::vector<Scalar>;
@ -101,33 +70,18 @@ class Constant {
/// @returns true if any scalar element is zero
bool AnyZero() const;
/// Calls `func(s)` with s being the current scalar value at `index`.
/// `func` is typically a lambda of the form '[](auto&& s)'.
/// @param index the index of the scalar value
/// @param func a function with signature `T(S)`
/// @return the value returned by func.
template <typename Func>
auto WithScalarAt(size_t index, Func&& func) const {
return Switch(
ElementType(), //
[&](const I32*) { return func(elems_[index].i32); },
[&](const U32*) { return func(elems_[index].u32); },
[&](const F16*) { return func(elems_[index].f16); },
[&](const F32*) { return func(elems_[index].f32); },
[&](const Bool*) { return func(elems_[index].bool_); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid scalar type " << type_->TypeInfo().name;
return func(u32(0u));
});
/// @return the value of the scalar at `index`, which must be of type `T`.
template <typename T>
T Element(size_t index) const {
return std::get<T>(elems_[index]);
}
/// @param index the index of the scalar value
/// @return the value of the scalar `static_cast` to type T.
template <typename T>
T ElementAs(size_t index) const {
return WithScalarAt(index, [](auto val) { return static_cast<T>(val); });
return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]);
}
private:

View File

@ -56,6 +56,21 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
return nullptr;
}
auto build_scalar = [&](sem::Constant::Scalar s) {
return Switch(
value.ElementType(), //
[&](const sem::I32*) { return ctx.dst->Expr(i32(std::get<AInt>(s).value)); },
[&](const sem::U32*) { return ctx.dst->Expr(u32(std::get<AInt>(s).value)); },
[&](const sem::F32*) { return ctx.dst->Expr(f32(std::get<AFloat>(s).value)); },
[&](const sem::Bool*) { return ctx.dst->Expr(std::get<bool>(s)); },
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled Constant::Scalar type: "
<< value.ElementType()->FriendlyName(ctx.src->Symbols());
return nullptr;
});
};
if (auto* vec = ty->As<sem::Vector>()) {
uint32_t vec_size = static_cast<uint32_t>(vec->Width());
@ -73,7 +88,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ast::ExpressionList ctors;
for (uint32_t i = 0; i < ctor_size; ++i) {
value.WithScalarAt(i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); });
ctors.emplace_back(build_scalar(value.Elements()[i]));
}
auto* el_ty = CreateASTTypeFor(ctx, vec->type());
@ -81,8 +96,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
}
if (ty->is_scalar()) {
return value.WithScalarAt(
0, [&](auto&& s) -> const ast::LiteralExpression* { return ctx.dst->Expr(s); });
return build_scalar(value.Elements()[0]);
}
return nullptr;

View File

@ -123,10 +123,10 @@ struct Robustness::State {
if (auto idx_constant = idx_sem->ConstantValue()) {
// Constant value index
if (idx_constant.Type()->Is<sem::I32>()) {
idx.i32 = idx_constant.Elements()[0].i32;
idx.i32 = static_cast<int32_t>(idx_constant.Element<AInt>(0).value);
idx.is_signed = true;
} else if (idx_constant.Type()->Is<sem::U32>()) {
idx.u32 = idx_constant.Elements()[0].u32;
idx.u32 = static_cast<uint32_t>(idx_constant.Element<AInt>(0).value);
idx.is_signed = false;
} else {
TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor "

View File

@ -359,13 +359,8 @@ struct ZeroInitWorkgroupMemory::State {
}
auto* sem = ctx.src->Sem().Get(expr);
if (auto c = sem->ConstantValue()) {
if (c.ElementType()->Is<sem::I32>()) {
workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
continue;
} else if (c.ElementType()->Is<sem::U32>()) {
workgroup_size_const *= c.Elements()[0].u32;
continue;
}
workgroup_size_const *= c.Element<AInt>(0).value;
continue;
}
// Constant value could not be found. Build expression instead.
workgroup_size_expr = [this, expr, size = workgroup_size_expr] {

View File

@ -660,13 +660,8 @@ bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expr
if (i != 0) {
out << ", ";
}
if (!val.WithScalarAt(i, [&](auto&& s) -> bool {
// Use std::equal_to to work around -Wfloat-equal warnings
using T = std::remove_reference_t<decltype(s)>;
auto equal_to = std::equal_to<T>{};
bool is_zero = equal_to(s, T(0));
return EmitValue(out, elem_ty, is_zero ? 1 : static_cast<int>(s));
})) {
auto s = val.Element<AInt>(i).value;
if (!EmitValue(out, elem_ty, (s == 0) ? 1 : static_cast<int>(s))) {
return false;
}
}
@ -1191,7 +1186,7 @@ bool GeneratorImpl::EmitUniformBufferAccess(
if (auto val = offset_arg->ConstantValue()) {
TINT_ASSERT(Writer, val.Type()->Is<sem::U32>());
scalar_offset_value = val.Elements()[0].u32;
scalar_offset_value = static_cast<uint32_t>(val.Element<AInt>(0).value);
scalar_offset_value /= 4; // bytes -> scalar index
scalar_offset_constant = true;
}
@ -2371,7 +2366,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
case sem::BuiltinType::kTextureGather:
out << ".Gather";
if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) {
switch (call->Arguments()[0]->ConstantValue().Elements()[0].i32) {
switch (call->Arguments()[0]->ConstantValue().Element<AInt>(0).value) {
case 0:
out << "Red";
break;

View File

@ -1136,8 +1136,8 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
break; // Other texture dimensions don't have an offset
}
}
auto c = component->ConstantValue().Elements()[0].i32;
switch (c) {
auto c = component->ConstantValue().Element<AInt>(0);
switch (c.value) {
case 0:
out << "component::x";
break;