tint: Simplify sem::Constant::Scalar

Migrate from a hand-rolled tagged-union of [i32, u32, f32, f16, bool]
types. Instead use a std::variant of [AInt, AFloat, bool]. The Constant
holds the actual type, so no information is lost with the reduced types.

Note: Currently integer constants are still limited to 32-bits in size.
This is enforced by the frontend.

Bug: tint:1504
Change-Id: I316957787649c454fffb532334159d726cd1fb2d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90643
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-17 20:51:04 +00:00 committed by Dawn LUCI CQ
parent c590373cfb
commit aaa9ba3043
11 changed files with 123 additions and 198 deletions

View File

@ -793,13 +793,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
continue; continue;
} }
// validator_.Validate and set the default value for this dimension. // validator_.Validate and set the default value for this dimension.
if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) { if (value.Element<AInt>(0).value < 1) {
AddError("workgroup_size argument must be at least 1", values[i]->source); AddError("workgroup_size argument must be at least 1", values[i]->source);
return false; return false;
} }
ws[i].value = ws[i].value = static_cast<uint32_t>(value.Element<AInt>(0).value);
is_i32 ? static_cast<uint32_t>(value.Elements()[0].i32) : value.Elements()[0].u32;
} }
current_function_->SetWorkgroupSize(std::move(ws)); current_function_->SetWorkgroupSize(std::move(ws));
@ -1855,13 +1854,12 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
if (ty->is_signed_integer_scalar() ? count_val.Elements()[0].i32 < 1 if (count_val.Element<AInt>(0).value < 1) {
: count_val.Elements()[0].u32 < 1u) {
AddError("array size must be at least 1", size_source); AddError("array size must be at least 1", size_source);
return nullptr; return nullptr;
} }
count = count_val.Elements()[0].u32; count = static_cast<uint32_t>(count_val.Element<AInt>(0).value);
} }
auto size = std::max<uint64_t>(count, 1) * stride; auto size = std::max<uint64_t>(count, 1) * stride;

View File

@ -37,13 +37,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
return Switch( return Switch(
literal, literal,
[&](const ast::IntLiteralExpression* lit) { [&](const ast::IntLiteralExpression* lit) {
if (lit->suffix == ast::IntLiteralExpression::Suffix::kU) { return sem::Constant{type, {AInt(lit->value)}};
return sem::Constant{type, {u32(lit->value)}};
}
return sem::Constant{type, {i32(lit->value)}};
}, },
[&](const ast::FloatLiteralExpression* lit) { [&](const ast::FloatLiteralExpression* lit) {
return sem::Constant{type, {f32(lit->value)}}; return sem::Constant{type, {AFloat(lit->value)}};
}, },
[&](const ast::BoolLiteralExpression* lit) { [&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {lit->value}}; return sem::Constant{type, {lit->value}};
@ -64,21 +61,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
// For zero value init, return 0s // For zero value init, return 0s
if (call->args.empty()) { if (call->args.empty()) {
if (elem_type->Is<sem::I32>()) { using Scalars = sem::Constant::Scalars;
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_i)); auto constant = Switch(
} elem_type,
if (elem_type->Is<sem::U32>()) { [&](const sem::I32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_u)); [&](const sem::U32*) { return sem::Constant(type, Scalars(result_size, AInt(0))); },
} [&](const sem::F32*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
// Add f16 zero scalar here [&](const sem::F16*) { return sem::Constant(type, Scalars(result_size, AFloat(0))); },
if (elem_type->Is<sem::F16>()) { [&](const sem::Bool*) { return sem::Constant(type, Scalars(result_size, false)); });
return sem::Constant(type, sem::Constant::Scalars(result_size, f16{0.f})); if (constant.IsValid()) {
} return constant;
if (elem_type->Is<sem::F32>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, 0_f));
}
if (elem_type->Is<sem::Bool>()) {
return sem::Constant(type, sem::Constant::Scalars(result_size, false));
} }
} }
@ -112,33 +104,14 @@ sem::Constant Resolver::ConstantCast(const sem::Constant& value,
sem::Constant::Scalars elems; sem::Constant::Scalars elems;
for (size_t i = 0; i < value.Elements().size(); ++i) { for (size_t i = 0; i < value.Elements().size(); ++i) {
// TODO(crbug.com/tint/1504): Check that value fits in new type
elems.emplace_back(Switch<sem::Constant::Scalar>( elems.emplace_back(Switch<sem::Constant::Scalar>(
target_elem_type, target_elem_type, //
[&](const sem::I32*) { [&](const sem::I32*) { return value.ElementAs<AInt>(i); },
return value.WithScalarAt(i, [](auto&& s) { // [&](const sem::U32*) { return value.ElementAs<AInt>(i); },
return i32(static_cast<int32_t>(s)); [&](const sem::F32*) { return value.ElementAs<AFloat>(i); },
}); [&](const sem::F16*) { return value.ElementAs<AFloat>(i); },
}, [&](const sem::Bool*) { return value.ElementAs<bool>(i); },
[&](const sem::U32*) {
return value.WithScalarAt(i, [](auto&& s) { //
return u32(static_cast<uint32_t>(s));
});
},
[&](const sem::F16*) {
return value.WithScalarAt(i, [](auto&& s) { //
return f16{static_cast<float>(s)};
});
},
[&](const sem::F32*) {
return value.WithScalarAt(i, [](auto&& s) { //
return static_cast<f32>(s);
});
},
[&](const sem::Bool*) {
return value.WithScalarAt(i, [](auto&& s) { //
return static_cast<bool>(s);
});
},
[&](Default) { [&](Default) {
diag::List diags; diag::List diags;
TINT_UNREACHABLE(Semantic, diags) TINT_UNREACHABLE(Semantic, diags)

View File

@ -39,7 +39,7 @@ TEST_F(ResolverConstantsTest, Scalar_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
} }
TEST_F(ResolverConstantsTest, Scalar_u32) { TEST_F(ResolverConstantsTest, Scalar_u32) {
@ -54,7 +54,7 @@ TEST_F(ResolverConstantsTest, Scalar_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
} }
TEST_F(ResolverConstantsTest, Scalar_f32) { TEST_F(ResolverConstantsTest, Scalar_f32) {
@ -69,7 +69,7 @@ TEST_F(ResolverConstantsTest, Scalar_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
} }
TEST_F(ResolverConstantsTest, Scalar_bool) { TEST_F(ResolverConstantsTest, Scalar_bool) {
@ -84,7 +84,7 @@ TEST_F(ResolverConstantsTest, Scalar_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type()); EXPECT_EQ(sem->ConstantValue().ElementType(), sem->Type());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 1u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
} }
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
@ -101,9 +101,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 0); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 0); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 0); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0);
} }
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
@ -120,9 +120,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 0u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 0u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 0u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 0u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 0u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 0u);
} }
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
@ -139,9 +139,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 0.0f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 0.0f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 0.0);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 0.0f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
} }
TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
@ -158,9 +158,9 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, false); EXPECT_EQ(sem->ConstantValue().Element<bool>(0), false);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, false); EXPECT_EQ(sem->ConstantValue().Element<bool>(2), false);
} }
TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
@ -177,9 +177,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 99); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 99); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 99); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99);
} }
TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
@ -196,9 +196,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 99u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 99u);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 99u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 99u);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 99u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 99u);
} }
TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
@ -215,9 +215,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 9.9f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 9.9f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 9.9f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 9.9f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 9.9f);
} }
TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
@ -234,9 +234,9 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(1), true);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
} }
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
@ -253,9 +253,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
} }
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
@ -272,9 +272,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
} }
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
@ -291,9 +291,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
} }
TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
@ -310,9 +310,9 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
} }
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
@ -329,9 +329,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
} }
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
@ -348,9 +348,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].u32, 1u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].u32, 2u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].u32, 3u); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
} }
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
@ -367,9 +367,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 1.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 2.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 3.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3.f);
} }
TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
@ -386,9 +386,9 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(0), true);
EXPECT_EQ(sem->ConstantValue().Elements()[1].bool_, false); EXPECT_EQ(sem->ConstantValue().Element<bool>(1), false);
EXPECT_EQ(sem->ConstantValue().Elements()[2].bool_, true); EXPECT_EQ(sem->ConstantValue().Element<bool>(2), true);
} }
TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) { TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
@ -405,9 +405,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_f32_to_32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].i32, 1); EXPECT_EQ(sem->ConstantValue().Element<AInt>(0).value, 1);
EXPECT_EQ(sem->ConstantValue().Elements()[1].i32, 2); EXPECT_EQ(sem->ConstantValue().Element<AInt>(1).value, 2);
EXPECT_EQ(sem->ConstantValue().Elements()[2].i32, 3); EXPECT_EQ(sem->ConstantValue().Element<AInt>(2).value, 3);
} }
TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) { TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
@ -424,9 +424,9 @@ TEST_F(ResolverConstantsTest, Vec3_Cast_u32_to_f32) {
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u); ASSERT_EQ(sem->ConstantValue().Elements().size(), 3u);
EXPECT_EQ(sem->ConstantValue().Elements()[0].f32, 10.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 10.f);
EXPECT_EQ(sem->ConstantValue().Elements()[1].f32, 20.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 20.f);
EXPECT_EQ(sem->ConstantValue().Elements()[2].f32, 30.f); EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 30.f);
} }
} // namespace } // namespace

View File

@ -1515,7 +1515,7 @@ bool Validator::TextureBuiltinFunction(const sem::Call* call) const {
if (is_const_expr) { if (is_const_expr) {
auto vector = builtin->Parameters()[index]->Type()->Is<sem::Vector>(); auto vector = builtin->Parameters()[index]->Type()->Is<sem::Vector>();
for (size_t i = 0; i < values.Elements().size(); i++) { for (size_t i = 0; i < values.Elements().size(); i++) {
auto value = values.Elements()[i].i32; auto value = values.Element<AInt>(i).value;
if (value < min || value > max) { if (value < min || value > max) {
if (vector) { if (vector) {
AddError("each component of the " + name + AddError("each component of the " + name +

View File

@ -60,16 +60,12 @@ Constant::~Constant() = default;
Constant& Constant::operator=(const Constant& rhs) = default; Constant& Constant::operator=(const Constant& rhs) = default;
bool Constant::AnyZero() const { bool Constant::AnyZero() const {
for (size_t i = 0; i < Elements().size(); ++i) { for (auto scalar : elems_) {
if (WithScalarAt(i, [&](auto&& s) { auto is_zero = [&](auto&& s) {
// Use std::equal_to to work around -Wfloat-equal warnings
using T = std::remove_reference_t<decltype(s)>; using T = std::remove_reference_t<decltype(s)>;
auto equal_to = std::equal_to<T>{}; return s == T(0);
if (equal_to(s, T(0))) { };
return true; if (std::visit(is_zero, scalar)) {
}
return false;
})) {
return true; return true;
} }
} }

View File

@ -15,6 +15,7 @@
#ifndef SRC_TINT_SEM_CONSTANT_H_ #ifndef SRC_TINT_SEM_CONSTANT_H_
#define SRC_TINT_SEM_CONSTANT_H_ #define SRC_TINT_SEM_CONSTANT_H_
#include <variant>
#include <vector> #include <vector>
#include "src/tint/program_builder.h" #include "src/tint/program_builder.h"
@ -26,40 +27,8 @@ namespace tint::sem {
/// list of scalar values. Value may be of a scalar or vector type. /// list of scalar values. Value may be of a scalar or vector type.
class Constant { class Constant {
public: public:
/// Scalar holds a single constant scalar value, as a union of an i32, u32, /// Scalar holds a single constant scalar value - one of: AInt, AFloat or bool.
/// f32 or boolean. using Scalar = std::variant<AInt, AFloat, bool>;
union Scalar {
/// The scalar value as a i32
tint::i32 i32;
/// The scalar value as a u32
tint::u32 u32;
/// The scalar value as a f32
tint::f32 f32;
/// The scalar value as a f16, internally stored as float
tint::f16 f16;
/// The scalar value as a bool
bool bool_;
/// Constructs the scalar with the i32 value `v`
/// @param v the value of the Scalar
Scalar(tint::i32 v) : i32(v) {} // NOLINT
/// Constructs the scalar with the u32 value `v`
/// @param v the value of the Scalar
Scalar(tint::u32 v) : u32(v) {} // NOLINT
/// Constructs the scalar with the f32 value `v`
/// @param v the value of the Scalar
Scalar(tint::f32 v) : f32(v) {} // NOLINT
/// Constructs the scalar with the f16 value `v`
/// @param v the value of the Scalar
Scalar(tint::f16 v) : f16({v}) {} // NOLINT
/// Constructs the scalar with the bool value `v`
/// @param v the value of the Scalar
Scalar(bool v) : bool_(v) {} // NOLINT
};
/// Scalars is a list of scalar values /// Scalars is a list of scalar values
using Scalars = std::vector<Scalar>; using Scalars = std::vector<Scalar>;
@ -101,33 +70,18 @@ class Constant {
/// @returns true if any scalar element is zero /// @returns true if any scalar element is zero
bool AnyZero() const; bool AnyZero() const;
/// Calls `func(s)` with s being the current scalar value at `index`.
/// `func` is typically a lambda of the form '[](auto&& s)'.
/// @param index the index of the scalar value /// @param index the index of the scalar value
/// @param func a function with signature `T(S)` /// @return the value of the scalar at `index`, which must be of type `T`.
/// @return the value returned by func. template <typename T>
template <typename Func> T Element(size_t index) const {
auto WithScalarAt(size_t index, Func&& func) const { return std::get<T>(elems_[index]);
return Switch(
ElementType(), //
[&](const I32*) { return func(elems_[index].i32); },
[&](const U32*) { return func(elems_[index].u32); },
[&](const F16*) { return func(elems_[index].f16); },
[&](const F32*) { return func(elems_[index].f32); },
[&](const Bool*) { return func(elems_[index].bool_); },
[&](Default) {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
<< "invalid scalar type " << type_->TypeInfo().name;
return func(u32(0u));
});
} }
/// @param index the index of the scalar value /// @param index the index of the scalar value
/// @return the value of the scalar `static_cast` to type T. /// @return the value of the scalar `static_cast` to type T.
template <typename T> template <typename T>
T ElementAs(size_t index) const { T ElementAs(size_t index) const {
return WithScalarAt(index, [](auto val) { return static_cast<T>(val); }); return std::visit([](auto val) { return static_cast<T>(val); }, elems_[index]);
} }
private: private:

View File

@ -56,6 +56,21 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
return nullptr; return nullptr;
} }
auto build_scalar = [&](sem::Constant::Scalar s) {
return Switch(
value.ElementType(), //
[&](const sem::I32*) { return ctx.dst->Expr(i32(std::get<AInt>(s).value)); },
[&](const sem::U32*) { return ctx.dst->Expr(u32(std::get<AInt>(s).value)); },
[&](const sem::F32*) { return ctx.dst->Expr(f32(std::get<AFloat>(s).value)); },
[&](const sem::Bool*) { return ctx.dst->Expr(std::get<bool>(s)); },
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unhandled Constant::Scalar type: "
<< value.ElementType()->FriendlyName(ctx.src->Symbols());
return nullptr;
});
};
if (auto* vec = ty->As<sem::Vector>()) { if (auto* vec = ty->As<sem::Vector>()) {
uint32_t vec_size = static_cast<uint32_t>(vec->Width()); uint32_t vec_size = static_cast<uint32_t>(vec->Width());
@ -73,7 +88,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ast::ExpressionList ctors; ast::ExpressionList ctors;
for (uint32_t i = 0; i < ctor_size; ++i) { for (uint32_t i = 0; i < ctor_size; ++i) {
value.WithScalarAt(i, [&](auto&& s) { ctors.emplace_back(ctx.dst->Expr(s)); }); ctors.emplace_back(build_scalar(value.Elements()[i]));
} }
auto* el_ty = CreateASTTypeFor(ctx, vec->type()); auto* el_ty = CreateASTTypeFor(ctx, vec->type());
@ -81,8 +96,7 @@ void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} }
if (ty->is_scalar()) { if (ty->is_scalar()) {
return value.WithScalarAt( return build_scalar(value.Elements()[0]);
0, [&](auto&& s) -> const ast::LiteralExpression* { return ctx.dst->Expr(s); });
} }
return nullptr; return nullptr;

View File

@ -123,10 +123,10 @@ struct Robustness::State {
if (auto idx_constant = idx_sem->ConstantValue()) { if (auto idx_constant = idx_sem->ConstantValue()) {
// Constant value index // Constant value index
if (idx_constant.Type()->Is<sem::I32>()) { if (idx_constant.Type()->Is<sem::I32>()) {
idx.i32 = idx_constant.Elements()[0].i32; idx.i32 = static_cast<int32_t>(idx_constant.Element<AInt>(0).value);
idx.is_signed = true; idx.is_signed = true;
} else if (idx_constant.Type()->Is<sem::U32>()) { } else if (idx_constant.Type()->Is<sem::U32>()) {
idx.u32 = idx_constant.Elements()[0].u32; idx.u32 = static_cast<uint32_t>(idx_constant.Element<AInt>(0).value);
idx.is_signed = false; idx.is_signed = false;
} else { } else {
TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor " TINT_ICE(Transform, b.Diagnostics()) << "unsupported constant value for accessor "

View File

@ -359,13 +359,8 @@ struct ZeroInitWorkgroupMemory::State {
} }
auto* sem = ctx.src->Sem().Get(expr); auto* sem = ctx.src->Sem().Get(expr);
if (auto c = sem->ConstantValue()) { if (auto c = sem->ConstantValue()) {
if (c.ElementType()->Is<sem::I32>()) { workgroup_size_const *= c.Element<AInt>(0).value;
workgroup_size_const *= static_cast<uint32_t>(c.Elements()[0].i32);
continue; continue;
} else if (c.ElementType()->Is<sem::U32>()) {
workgroup_size_const *= c.Elements()[0].u32;
continue;
}
} }
// Constant value could not be found. Build expression instead. // Constant value could not be found. Build expression instead.
workgroup_size_expr = [this, expr, size = workgroup_size_expr] { workgroup_size_expr = [this, expr, size = workgroup_size_expr] {

View File

@ -660,13 +660,8 @@ bool GeneratorImpl::EmitExpressionOrOneIfZero(std::ostream& out, const ast::Expr
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
} }
if (!val.WithScalarAt(i, [&](auto&& s) -> bool { auto s = val.Element<AInt>(i).value;
// Use std::equal_to to work around -Wfloat-equal warnings if (!EmitValue(out, elem_ty, (s == 0) ? 1 : static_cast<int>(s))) {
using T = std::remove_reference_t<decltype(s)>;
auto equal_to = std::equal_to<T>{};
bool is_zero = equal_to(s, T(0));
return EmitValue(out, elem_ty, is_zero ? 1 : static_cast<int>(s));
})) {
return false; return false;
} }
} }
@ -1191,7 +1186,7 @@ bool GeneratorImpl::EmitUniformBufferAccess(
if (auto val = offset_arg->ConstantValue()) { if (auto val = offset_arg->ConstantValue()) {
TINT_ASSERT(Writer, val.Type()->Is<sem::U32>()); TINT_ASSERT(Writer, val.Type()->Is<sem::U32>());
scalar_offset_value = val.Elements()[0].u32; scalar_offset_value = static_cast<uint32_t>(val.Element<AInt>(0).value);
scalar_offset_value /= 4; // bytes -> scalar index scalar_offset_value /= 4; // bytes -> scalar index
scalar_offset_constant = true; scalar_offset_constant = true;
} }
@ -2371,7 +2366,7 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
case sem::BuiltinType::kTextureGather: case sem::BuiltinType::kTextureGather:
out << ".Gather"; out << ".Gather";
if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) { if (builtin->Parameters()[0]->Usage() == sem::ParameterUsage::kComponent) {
switch (call->Arguments()[0]->ConstantValue().Elements()[0].i32) { switch (call->Arguments()[0]->ConstantValue().Element<AInt>(0).value) {
case 0: case 0:
out << "Red"; out << "Red";
break; break;

View File

@ -1136,8 +1136,8 @@ bool GeneratorImpl::EmitTextureCall(std::ostream& out,
break; // Other texture dimensions don't have an offset break; // Other texture dimensions don't have an offset
} }
} }
auto c = component->ConstantValue().Elements()[0].i32; auto c = component->ConstantValue().Element<AInt>(0);
switch (c) { switch (c.value) {
case 0: case 0:
out << "component::x"; out << "component::x";
break; break;