Moved sem::ArrayCount to an inherited structure

This CL moves the ArrayCount from a variant to use inheritance. This
will allow the sem to have different array count classes from the IR.
The ArrayCounts, similar to types, are unique across the code base and
are provided by the TypeManager.

Bug: tint:1718
Change-Id: Ib9c7c9df881e7a34cc3def2ff29571f536d66244
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112441
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2022-12-01 23:45:18 +00:00 committed by Dawn LUCI CQ
parent 71e6bcf1af
commit 4b1d79e292
26 changed files with 439 additions and 302 deletions

View File

@ -419,6 +419,7 @@ libtint_source_set("libtint_core_all_src") {
"sem/abstract_int.h",
"sem/abstract_numeric.h",
"sem/array.h",
"sem/array_count.h",
"sem/atomic.h",
"sem/behavior.h",
"sem/binding_point.h",
@ -635,6 +636,8 @@ libtint_source_set("libtint_sem_src") {
"sem/abstract_numeric.h",
"sem/array.cc",
"sem/array.h",
"sem/array_count.cc",
"sem/array_count.h",
"sem/atomic.cc",
"sem/atomic.h",
"sem/behavior.cc",

View File

@ -299,6 +299,8 @@ list(APPEND TINT_LIB_SRCS
sem/abstract_numeric.h
sem/array.cc
sem/array.h
sem/array_count.cc
sem/array_count.h
sem/atomic.cc
sem/atomic.h
sem/behavior.cc

View File

@ -91,6 +91,7 @@
#include "src/tint/program.h"
#include "src/tint/program_id.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/array_count.h"
#include "src/tint/sem/bool.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/depth_texture.h"
@ -457,7 +458,8 @@ class ProgramBuilder {
/// @returns the node pointer
template <typename T, typename... ARGS>
traits::EnableIf<traits::IsTypeOrDerived<T, sem::Node> &&
!traits::IsTypeOrDerived<T, sem::Type>,
!traits::IsTypeOrDerived<T, sem::Type> &&
!traits::IsTypeOrDerived<T, sem::ArrayCount>,
T>*
create(ARGS&&... args) {
AssertNotMoved();
@ -476,17 +478,28 @@ class ProgramBuilder {
/// Creates a new sem::Type owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the
/// returned`Type` will also be destructed.
/// returned `Type` will also be destructed.
/// Types are unique (de-aliased), and so calling create() for the same `T`
/// and arguments will return the same pointer.
/// @param args the arguments to pass to the type constructor
/// @returns the de-aliased type pointer
template <typename T, typename... ARGS>
traits::EnableIfIsType<T, sem::Type>* create(ARGS&&... args) {
static_assert(std::is_base_of<sem::Type, T>::value, "T does not derive from sem::Type");
AssertNotMoved();
return types_.Get<T>(std::forward<ARGS>(args)...);
}
/// Creates a new sem::ArrayCount owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the
/// returned `ArrayCount` will also be destructed.
/// ArrayCounts are unique (de-aliased), and so calling create() for the same `T`
/// and arguments will return the same pointer.
/// @param args the arguments to pass to the array count constructor
/// @returns the de-aliased array count pointer
template <typename T, typename... ARGS>
traits::EnableIfIsType<T, sem::ArrayCount>* create(ARGS&&... args) {
AssertNotMoved();
return types_.GetArrayCount<T>(std::forward<ARGS>(args)...);
}
/// Marks this builder as moved, preventing any further use of the builder.
void MarkAsMoved();

View File

@ -1321,7 +1321,6 @@ TEST_F(ResolverConstEvalTest, Array_i32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1359,7 +1358,6 @@ TEST_F(ResolverConstEvalTest, Array_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1397,7 +1395,6 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1449,7 +1446,6 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1487,7 +1483,6 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1525,7 +1520,6 @@ TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1564,7 +1558,6 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1594,7 +1587,6 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());

View File

@ -135,8 +135,8 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest, ResolverInferredTypeParamTest, testing::V
TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
auto* type = ty.array(ty.u32(), 10_u);
auto* expected_type =
create<sem::Array>(create<sem::U32>(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u);
auto* expected_type = create<sem::Array>(
create<sem::U32>(), create<sem::ConstantArrayCount>(10u), 4u, 4u * 10u, 4u, 4u);
auto* ctor_expr = Construct(type);
auto* var = Var("a", ast::AddressSpace::kFunction, ctor_expr);

View File

@ -532,12 +532,13 @@ bool match_array(MatchState&, const sem::Type* ty, const sem::Type*& T) {
}
const sem::Array* build_array(MatchState& state, const sem::Type* el) {
return state.builder.create<sem::Array>(el,
/* count */ sem::RuntimeArrayCount{},
/* align */ 0u,
/* size */ 0u,
/* stride */ 0u,
/* stride_implicit */ 0u);
return state.builder.create<sem::Array>(
el,
/* count */ state.builder.create<sem::RuntimeArrayCount>(),
/* align */ 0u,
/* size */ 0u,
/* stride */ 0u,
/* stride_implicit */ 0u);
}
bool match_ptr(MatchState&, const sem::Type* ty, Number& S, const sem::Type*& T, Number& A) {

View File

@ -252,7 +252,8 @@ TEST_F(IntrinsicTableTest, MismatchPointer) {
}
TEST_F(IntrinsicTableTest, MatchArray) {
auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::U32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
auto* arr_ptr = create<sem::Pointer>(arr, ast::AddressSpace::kStorage, ast::Access::kReadWrite);
auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr},
sem::EvaluationStage::kConstant, Source{});
@ -955,7 +956,8 @@ TEST_F(IntrinsicTableTest, MatchTypeConversion) {
}
TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::U32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
auto* f32 = create<sem::F32>();
auto result = table->Lookup(InitConvIntrinsic::kVec3, f32, utils::Vector{arr},
sem::EvaluationStage::kConstant, Source{{12, 34}});

View File

@ -106,13 +106,14 @@ TEST_F(ResolverIsHostShareable, Atomic) {
}
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}

View File

@ -89,13 +89,14 @@ TEST_F(ResolverIsStorableTest, Atomic) {
}
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}

View File

@ -2143,8 +2143,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
[&](const ast::Array* a) -> sem::Call* {
Mark(a);
// array element type must be inferred if it was not specified.
sem::ArrayCount el_count =
sem::ConstantArrayCount{static_cast<uint32_t>(args.Length())};
const sem::ArrayCount* el_count = nullptr;
const sem::Type* el_ty = nullptr;
if (a->type) {
el_ty = Type(a->type);
@ -2155,14 +2154,15 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
AddError("cannot construct a runtime-sized array", expr->source);
return nullptr;
}
if (auto count = ArrayCount(a->count)) {
el_count = count.Get();
} else {
el_count = ArrayCount(a->count);
if (!el_count) {
return nullptr;
}
// Note: validation later will detect any mismatches between explicit array
// size and number of initializer expressions.
} else {
el_count = builder_->create<sem::ConstantArrayCount>(
static_cast<uint32_t>(args.Length()));
auto arg_tys =
utils::Transform(args, [](auto* arg) { return arg->Type()->UnwrapRef(); });
el_ty = sem::Type::Common(arg_tys);
@ -2936,15 +2936,16 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr;
}
sem::ArrayCount el_count = sem::RuntimeArrayCount{};
const sem::ArrayCount* el_count = nullptr;
// Evaluate the constant array count expression.
if (auto* count_expr = arr->count) {
if (auto count = ArrayCount(count_expr)) {
el_count = count.Get();
} else {
el_count = ArrayCount(count_expr);
if (!el_count) {
return nullptr;
}
} else {
el_count = builder_->create<sem::RuntimeArrayCount>();
}
auto* out = Array(arr->type->source, //
@ -2971,11 +2972,11 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return out;
}
utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count_expr) {
const sem::ArrayCount* Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array count expression.
const auto* count_sem = Materialize(Expression(count_expr));
if (!count_sem) {
return utils::Failure;
return nullptr;
}
if (count_sem->Stage() == sem::EvaluationStage::kOverride) {
@ -2983,34 +2984,34 @@ utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count
// Is the count a named 'override'?
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
return sem::ArrayCount{sem::NamedOverrideArrayCount{global}};
return builder_->create<sem::NamedOverrideArrayCount>(global);
}
}
return sem::ArrayCount{sem::UnnamedOverrideArrayCount{count_sem}};
return builder_->create<sem::UnnamedOverrideArrayCount>(count_sem);
}
auto* count_val = count_sem->ConstantValue();
if (!count_val) {
AddError("array count must evaluate to a constant integer expression or override variable",
count_expr->source);
return utils::Failure;
return nullptr;
}
if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) {
AddError("array count must evaluate to a constant integer expression, but is type '" +
builder_->FriendlyName(ty) + "'",
count_expr->source);
return utils::Failure;
return nullptr;
}
int64_t count = count_val->As<AInt>();
if (count < 1) {
AddError("array count (" + std::to_string(count) + ") must be greater than 0",
count_expr->source);
return utils::Failure;
return nullptr;
}
return sem::ArrayCount{sem::ConstantArrayCount{static_cast<uint32_t>(count)}};
return builder_->create<sem::ConstantArrayCount>(static_cast<uint32_t>(count));
}
bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes,
@ -3046,7 +3047,7 @@ bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attribute
sem::Array* Resolver::Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
sem::ArrayCount el_count,
const sem::ArrayCount* el_count,
uint32_t explicit_stride) {
uint32_t el_align = el_ty->Align();
uint32_t el_size = el_ty->Size();
@ -3054,7 +3055,7 @@ sem::Array* Resolver::Array(const Source& el_source,
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
uint64_t size = 0;
if (auto const_count = std::get_if<sem::ConstantArrayCount>(&el_count)) {
if (auto const_count = el_count->As<sem::ConstantArrayCount>()) {
size = const_count->value * stride;
if (size > std::numeric_limits<uint32_t>::max()) {
std::stringstream msg;
@ -3063,7 +3064,7 @@ sem::Array* Resolver::Array(const Source& el_source,
AddError(msg.str(), count_source);
return nullptr;
}
} else if (std::holds_alternative<sem::RuntimeArrayCount>(el_count)) {
} else if (el_count->Is<sem::RuntimeArrayCount>()) {
size = stride;
}
auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),

View File

@ -273,7 +273,7 @@ class Resolver {
/// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array.
utils::Result<sem::ArrayCount> ArrayCount(const ast::Expression* count_expr);
const sem::ArrayCount* ArrayCount(const ast::Expression* count_expr);
/// Resolves and validates the attributes on an array.
/// @param attributes the attributes on the array type.
@ -296,7 +296,7 @@ class Resolver {
sem::Array* Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
sem::ArrayCount el_count,
const sem::ArrayCount* el_count,
uint32_t explicit_stride);
/// Builds and returns the semantic information for the alias `alias`.

View File

@ -440,7 +440,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_SignedLiteral) {
@ -453,7 +453,7 @@ TEST_F(ResolverTest, ArraySize_SignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_UnsignedConst) {
@ -468,7 +468,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_SignedConst) {
@ -483,7 +483,7 @@ TEST_F(ResolverTest, ArraySize_SignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
EXPECT_EQ(ary->Count(), create<sem::ConstantArrayCount>(10u));
}
TEST_F(ResolverTest, ArraySize_NamedOverride) {
@ -500,7 +500,7 @@ TEST_F(ResolverTest, ArraySize_NamedOverride) {
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), sem::NamedOverrideArrayCount{sem_override});
EXPECT_EQ(ary->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
}
TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) {
@ -525,8 +525,8 @@ TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) {
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), sem::NamedOverrideArrayCount{sem_override});
EXPECT_EQ(ary_b->Count(), sem::NamedOverrideArrayCount{sem_override});
EXPECT_EQ(ary_a->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
EXPECT_EQ(ary_b->Count(), create<sem::NamedOverrideArrayCount>(sem_override));
EXPECT_EQ(ary_a, ary_b);
}
@ -545,7 +545,7 @@ TEST_F(ResolverTest, ArraySize_UnnamedOverride) {
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(cnt)});
EXPECT_EQ(ary->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(cnt)));
}
TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) {
@ -572,8 +572,8 @@ TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) {
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(a_cnt)});
EXPECT_EQ(ary_b->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(b_cnt)});
EXPECT_EQ(ary_a->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(a_cnt)));
EXPECT_EQ(ary_b->Count(), create<sem::UnnamedOverrideArrayCount>(Sem().Get(b_cnt)));
EXPECT_NE(ary_a, ary_b);
}

View File

@ -659,9 +659,11 @@ struct DataType<array<N, T>> {
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* el = DataType<T>::Sem(b);
sem::ArrayCount count = sem::ConstantArrayCount{N};
const sem::ArrayCount* count = nullptr;
if (N == 0) {
count = sem::RuntimeArrayCount{};
count = b.create<sem::RuntimeArrayCount>();
} else {
count = b.create<sem::ConstantArrayCount>(N);
}
return b.create<sem::Array>(
/* element */ el,

View File

@ -1778,7 +1778,12 @@ bool Validator::ArrayInitializer(const ast::CallExpression* ctor,
return false;
}
const auto count = std::get<sem::ConstantArrayCount>(array_type->Count()).value;
if (!array_type->IsConstantSized()) {
TINT_ICE(Resolver, diagnostics_) << "Invalid ArrayCount found";
return false;
}
const auto count = array_type->Count()->As<sem::ConstantArrayCount>()->value;
if (!values.IsEmpty() && (values.Length() != count)) {
std::string fm = values.Length() < count ? "few" : "many";
AddError("array initializer has too " + fm + " elements: expected " +

View File

@ -89,13 +89,14 @@ TEST_F(ValidatorIsStorableTest, Atomic) {
}
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), create<sem::ConstantArrayCount>(5u), 4u, 20u,
4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), create<sem::RuntimeArrayCount>(), 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}

View File

@ -28,10 +28,10 @@ namespace tint::sem {
namespace {
TypeFlags FlagsFrom(const Type* element, ArrayCount count) {
TypeFlags FlagsFrom(const Type* element, const ArrayCount* count) {
TypeFlags flags;
// Only constant-expression sized arrays are constructible
if (std::holds_alternative<ConstantArrayCount>(count)) {
if (count->Is<ConstantArrayCount>()) {
if (element->IsConstructible()) {
flags.Add(TypeFlag::kConstructable);
}
@ -39,9 +39,7 @@ TypeFlags FlagsFrom(const Type* element, ArrayCount count) {
flags.Add(TypeFlag::kCreationFixedFootprint);
}
}
if (std::holds_alternative<ConstantArrayCount>(count) ||
std::holds_alternative<NamedOverrideArrayCount>(count) ||
std::holds_alternative<UnnamedOverrideArrayCount>(count)) {
if (count->IsAnyOf<ConstantArrayCount, NamedOverrideArrayCount, UnnamedOverrideArrayCount>()) {
if (element->HasFixedFootprint()) {
flags.Add(TypeFlag::kFixedFootprint);
}
@ -56,7 +54,7 @@ const char* const Array::kErrExpectedConstantCount =
"Was the SubstituteOverride transform run?";
Array::Array(const Type* element,
ArrayCount count,
const ArrayCount* count,
uint32_t align,
uint32_t size,
uint32_t stride,
@ -91,11 +89,11 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
out << "@stride(" << stride_ << ") ";
}
out << "array<" << element_->FriendlyName(symbols);
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
if (auto* const_count = count_->As<ConstantArrayCount>()) {
out << ", " << const_count->value;
} else if (auto* named_override_count = std::get_if<NamedOverrideArrayCount>(&count_)) {
} else if (auto* named_override_count = count_->As<NamedOverrideArrayCount>()) {
out << ", " << symbols.NameFor(named_override_count->variable->Declaration()->symbol);
} else if (std::holds_alternative<UnnamedOverrideArrayCount>(count_)) {
} else if (count_->Is<UnnamedOverrideArrayCount>()) {
out << ", [unnamed override-expression]";
}
out << ">";

View File

@ -20,6 +20,7 @@
#include <string>
#include <variant>
#include "src/tint/sem/array_count.h"
#include "src/tint/sem/node.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/compiler_macros.h"
@ -33,115 +34,6 @@ class GlobalVariable;
namespace tint::sem {
/// The variant of an ArrayCount when the array is a const-expression.
/// Example:
/// ```
/// const N = 123;
/// type arr = array<i32, N>
/// ```
struct ConstantArrayCount {
/// The array count constant-expression value.
uint32_t value;
};
/// The variant of an ArrayCount when the count is a named override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N>
/// ```
struct NamedOverrideArrayCount {
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the count is an unnamed override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N*2>
/// ```
struct UnnamedOverrideArrayCount {
/// The unnamed override expression.
/// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
/// expressions will not result in the same `expr` pointer. This property is important to ensure
/// that two array declarations with equivalent AST expressions do not compare equal.
/// For example, consider:
/// ```
/// override size : u32;
/// var<workgroup> a : array<f32, size * 2>;
/// var<workgroup> b : array<f32, size * 2>;
/// ```
// The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
// `b` must not compare equal.
const Expression* expr;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
/// type arr = array<i32>
/// ```
struct RuntimeArrayCount {};
/// An array count is either a constant-expression value, a named override identifier, an unnamed
/// override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount,
NamedOverrideArrayCount,
UnnamedOverrideArrayCount,
RuntimeArrayCount>;
/// Equality operator
/// @param a the LHS ConstantArrayCount
/// @param b the RHS ConstantArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
return a.value == b.value;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const NamedOverrideArrayCount& a, const NamedOverrideArrayCount& b) {
return a.variable == b.variable;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const UnnamedOverrideArrayCount& a, const UnnamedOverrideArrayCount& b) {
return a.expr == b.expr;
}
/// Equality operator
/// @returns true
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
return true;
}
/// Equality operator
/// @param a the LHS ArrayCount
/// @param b the RHS count
/// @returns true if @p a is equal to @p b
template <typename T,
typename = std::enable_if_t<
std::is_same_v<T, ConstantArrayCount> || std::is_same_v<T, NamedOverrideArrayCount> ||
std::is_same_v<T, UnnamedOverrideArrayCount> || std::is_same_v<T, RuntimeArrayCount>>>
inline bool operator==(const ArrayCount& a, const T& b) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
return std::visit(
[&](auto count) {
if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
return count == b;
}
return false;
},
a);
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
/// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> {
public:
@ -161,7 +53,7 @@ class Array final : public Castable<Array, Type> {
/// of the array to the start of the next element, if there was no `@stride`
/// attribute applied.
Array(Type const* element,
ArrayCount count,
const ArrayCount* count,
uint32_t align,
uint32_t size,
uint32_t stride,
@ -178,11 +70,11 @@ class Array final : public Castable<Array, Type> {
Type const* ElemType() const { return element_; }
/// @returns the number of elements in the array.
const ArrayCount& Count() const { return count_; }
const ArrayCount* Count() const { return count_; }
/// @returns the array count if the count is a const-expression, otherwise returns nullopt.
inline std::optional<uint32_t> ConstantCount() const {
if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
if (auto* count = count_->As<ConstantArrayCount>()) {
return count->value;
}
return std::nullopt;
@ -212,23 +104,19 @@ class Array final : public Castable<Array, Type> {
bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
/// @returns true if this array is sized using an const-expression
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
bool IsConstantSized() const { return count_->Is<ConstantArrayCount>(); }
/// @returns true if this array is sized using a named override variable
bool IsNamedOverrideSized() const {
return std::holds_alternative<NamedOverrideArrayCount>(count_);
}
bool IsNamedOverrideSized() const { return count_->Is<NamedOverrideArrayCount>(); }
/// @returns true if this array is sized using an unnamed override variable
bool IsUnnamedOverrideSized() const {
return std::holds_alternative<UnnamedOverrideArrayCount>(count_);
}
bool IsUnnamedOverrideSized() const { return count_->Is<UnnamedOverrideArrayCount>(); }
/// @returns true if this array is sized using a named or unnamed override variable
bool IsOverrideSized() const { return IsNamedOverrideSized() || IsUnnamedOverrideSized(); }
/// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
bool IsRuntimeSized() const { return count_->Is<RuntimeArrayCount>(); }
/// @param symbols the program's symbol table
/// @returns the name for this type that closely resembles how it would be
@ -237,7 +125,7 @@ class Array final : public Castable<Array, Type> {
private:
Type const* const element_;
const ArrayCount count_;
const ArrayCount* count_;
const uint32_t align_;
const uint32_t size_;
const uint32_t stride_;
@ -246,49 +134,4 @@ class Array final : public Castable<Array, Type> {
} // namespace tint::sem
namespace std {
/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
template <>
class hash<tint::sem::ConstantArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
return std::hash<decltype(count.value)>()(count.value);
}
};
/// Custom std::hash specialization for tint::sem::NamedOverrideArrayCount.
template <>
class hash<tint::sem::NamedOverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::NamedOverrideArrayCount& count) const {
return std::hash<decltype(count.variable)>()(count.variable);
}
};
/// Custom std::hash specialization for tint::sem::UnnamedOverrideArrayCount.
template <>
class hash<tint::sem::UnnamedOverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::UnnamedOverrideArrayCount& count) const {
return std::hash<decltype(count.expr)>()(count.expr);
}
};
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
template <>
class hash<tint::sem::RuntimeArrayCount> {
public:
/// @return the hash value
inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
};
} // namespace std
#endif // SRC_TINT_SEM_ARRAY_H_

View File

@ -0,0 +1,82 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/sem/array_count.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::ArrayCount);
TINT_INSTANTIATE_TYPEINFO(tint::sem::ConstantArrayCount);
TINT_INSTANTIATE_TYPEINFO(tint::sem::RuntimeArrayCount);
TINT_INSTANTIATE_TYPEINFO(tint::sem::NamedOverrideArrayCount);
TINT_INSTANTIATE_TYPEINFO(tint::sem::UnnamedOverrideArrayCount);
namespace tint::sem {
ArrayCount::ArrayCount() : Base() {}
ArrayCount::~ArrayCount() = default;
ConstantArrayCount::ConstantArrayCount(uint32_t val) : Base(), value(val) {}
ConstantArrayCount::~ConstantArrayCount() = default;
size_t ConstantArrayCount::Hash() const {
return static_cast<size_t>(TypeInfo::Of<ConstantArrayCount>().full_hashcode);
}
bool ConstantArrayCount::Equals(const ArrayCount& other) const {
if (auto* v = other.As<ConstantArrayCount>()) {
return value == v->value;
}
return false;
}
RuntimeArrayCount::RuntimeArrayCount() : Base() {}
RuntimeArrayCount::~RuntimeArrayCount() = default;
size_t RuntimeArrayCount::Hash() const {
return static_cast<size_t>(TypeInfo::Of<RuntimeArrayCount>().full_hashcode);
}
bool RuntimeArrayCount::Equals(const ArrayCount& other) const {
return other.Is<RuntimeArrayCount>();
}
NamedOverrideArrayCount::NamedOverrideArrayCount(const GlobalVariable* var)
: Base(), variable(var) {}
NamedOverrideArrayCount::~NamedOverrideArrayCount() = default;
size_t NamedOverrideArrayCount::Hash() const {
return static_cast<size_t>(TypeInfo::Of<NamedOverrideArrayCount>().full_hashcode);
}
bool NamedOverrideArrayCount::Equals(const ArrayCount& other) const {
if (auto* v = other.As<NamedOverrideArrayCount>()) {
return variable == v->variable;
}
return false;
}
UnnamedOverrideArrayCount::UnnamedOverrideArrayCount(const Expression* e) : Base(), expr(e) {}
UnnamedOverrideArrayCount::~UnnamedOverrideArrayCount() = default;
size_t UnnamedOverrideArrayCount::Hash() const {
return static_cast<size_t>(TypeInfo::Of<UnnamedOverrideArrayCount>().full_hashcode);
}
bool UnnamedOverrideArrayCount::Equals(const ArrayCount& other) const {
if (auto* v = other.As<UnnamedOverrideArrayCount>()) {
return expr == v->expr;
}
return false;
}
} // namespace tint::sem

170
src/tint/sem/array_count.h Normal file
View File

@ -0,0 +1,170 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_SEM_ARRAY_COUNT_H_
#define SRC_TINT_SEM_ARRAY_COUNT_H_
#include <functional>
#include <string>
#include "src/tint/sem/expression.h"
#include "src/tint/sem/node.h"
#include "src/tint/sem/variable.h"
namespace tint::sem {
/// An array count
class ArrayCount : public Castable<ArrayCount, Node> {
public:
~ArrayCount() override;
/// @returns a hash of the array count.
virtual size_t Hash() const = 0;
/// @param t other array count
/// @returns true if this array count is equal to the given array count
virtual bool Equals(const ArrayCount& t) const = 0;
protected:
ArrayCount();
};
/// The variant of an ArrayCount when the array is a const-expression.
/// Example:
/// ```
/// const N = 123;
/// type arr = array<i32, N>
/// ```
class ConstantArrayCount final : public Castable<ConstantArrayCount, ArrayCount> {
public:
/// Constructor
/// @param val the constant-expression value
explicit ConstantArrayCount(uint32_t val);
~ConstantArrayCount() override;
/// @returns a hash of the array count.
size_t Hash() const override;
/// @param t other array count
/// @returns true if this array count is equal to the given array count
bool Equals(const ArrayCount& t) const override;
/// The array count constant-expression value.
uint32_t value;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
/// type arr = array<i32>
/// ```
class RuntimeArrayCount final : public Castable<RuntimeArrayCount, ArrayCount> {
public:
/// Constructor
RuntimeArrayCount();
~RuntimeArrayCount() override;
/// @returns a hash of the array count.
size_t Hash() const override;
/// @param t other array count
/// @returns true if this array count is equal to the given array count
bool Equals(const ArrayCount& t) const override;
};
/// The variant of an ArrayCount when the count is a named override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N>
/// ```
class NamedOverrideArrayCount final : public Castable<NamedOverrideArrayCount, ArrayCount> {
public:
/// Constructor
/// @param var the `override` variable
explicit NamedOverrideArrayCount(const GlobalVariable* var);
~NamedOverrideArrayCount() override;
/// @returns a hash of the array count.
size_t Hash() const override;
/// @param t other array count
/// @returns true if this array count is equal to the given array count
bool Equals(const ArrayCount& t) const override;
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the count is an unnamed override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N*2>
/// ```
class UnnamedOverrideArrayCount final : public Castable<UnnamedOverrideArrayCount, ArrayCount> {
public:
/// Constructor
/// @param e the override expression
explicit UnnamedOverrideArrayCount(const Expression* e);
~UnnamedOverrideArrayCount() override;
/// @returns a hash of the array count.
size_t Hash() const override;
/// @param t other array count
/// @returns true if this array count is equal to the given array count
bool Equals(const ArrayCount& t) const override;
/// The unnamed override expression.
/// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
/// expressions will not result in the same `expr` pointer. This property is important to ensure
/// that two array declarations with equivalent AST expressions do not compare equal.
/// For example, consider:
/// ```
/// override size : u32;
/// var<workgroup> a : array<f32, size * 2>;
/// var<workgroup> b : array<f32, size * 2>;
/// ```
// The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
// `b` must not compare equal.
const Expression* expr;
};
} // namespace tint::sem
namespace std {
/// std::hash specialization for tint::sem::ArrayCount
template <>
struct hash<tint::sem::ArrayCount> {
/// @param a the array count to obtain a hash from
/// @returns the hash of the array count
size_t operator()(const tint::sem::ArrayCount& a) const { return a.Hash(); }
};
/// std::equal_to specialization for tint::sem::ArrayCount
template <>
struct equal_to<tint::sem::ArrayCount> {
/// @param a the first array count to compare
/// @param b the second array count to compare
/// @returns true if the two array counts are equal
bool operator()(const tint::sem::ArrayCount& a, const tint::sem::ArrayCount& b) const {
return a.Equals(b);
}
};
} // namespace std
#endif // SRC_TINT_SEM_ARRAY_COUNT_H_

View File

@ -21,16 +21,16 @@ namespace {
using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) {
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
EXPECT_EQ(a->Count(), create<ConstantArrayCount>(2u));
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@ -47,15 +47,15 @@ TEST_F(ArrayTest, CreateSizedArray) {
}
TEST_F(ArrayTest, CreateRuntimeArray) {
auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
EXPECT_EQ(a->Count(), create<sem::RuntimeArrayCount>());
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@ -71,13 +71,13 @@ TEST_F(ArrayTest, CreateRuntimeArray) {
}
TEST_F(ArrayTest, Hash) {
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash());
@ -88,13 +88,13 @@ TEST_F(ArrayTest, Hash) {
}
TEST_F(ArrayTest, Equals) {
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), create<ConstantArrayCount>(3u), 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@ -106,32 +106,34 @@ TEST_F(ArrayTest, Equals) {
}
TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), create<RuntimeArrayCount>(), 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSized) {
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), create<ConstantArrayCount>(5u), 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
}
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), create<RuntimeArrayCount>(), 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), create<ConstantArrayCount>(5u), 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
}
TEST_F(ArrayTest, IsConstructable) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* fixed_sized =
create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* runtime_sized =
create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->IsConstructible());
EXPECT_FALSE(named_override_sized->IsConstructible());
@ -140,12 +142,14 @@ TEST_F(ArrayTest, IsConstructable) {
}
TEST_F(ArrayTest, HasCreationFixedFootprint) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* fixed_sized =
create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* runtime_sized =
create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint());
EXPECT_FALSE(named_override_sized->HasCreationFixedFootprint());
@ -154,12 +158,14 @@ TEST_F(ArrayTest, HasCreationFixedFootprint) {
}
TEST_F(ArrayTest, HasFixedFootprint) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* fixed_sized =
create<Array>(create<U32>(), create<ConstantArrayCount>(2u), 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<NamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
create<Array>(create<U32>(), create<UnnamedOverrideArrayCount>(nullptr), 4u, 8u, 32u, 16u);
auto* runtime_sized =
create<Array>(create<U32>(), create<RuntimeArrayCount>(), 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasFixedFootprint());
EXPECT_TRUE(named_override_sized->HasFixedFootprint());

View File

@ -273,7 +273,7 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
},
[&](const Array* a) {
if (count) {
if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
if (auto* const_count = a->Count()->As<ConstantArrayCount>()) {
*count = const_count->value;
}
}

View File

@ -19,6 +19,7 @@
#include <unordered_map>
#include <utility>
#include "src/tint/sem/array_count.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/unique_allocator.h"
@ -56,6 +57,7 @@ class TypeManager final {
static TypeManager Wrap(const TypeManager& inner) {
TypeManager out;
out.types_.Wrap(inner.types_);
out.array_counts_.Wrap(inner.array_counts_);
return out;
}
@ -80,6 +82,17 @@ class TypeManager final {
return types_.Find<TYPE>(std::forward<ARGS>(args)...);
}
/// @param args the arguments used to construct the object.
/// @return a pointer to an instance of `T` with the provided arguments.
/// If an existing instance of `T` has been constructed, then the same
/// pointer is returned.
template <typename TYPE,
typename _ = std::enable_if<traits::IsTypeOrDerived<TYPE, sem::ArrayCount>>,
typename... ARGS>
TYPE* GetArrayCount(ARGS&&... args) {
return array_counts_.Get<TYPE>(std::forward<ARGS>(args)...);
}
/// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types
@ -87,6 +100,7 @@ class TypeManager final {
private:
utils::UniqueAllocator<Type> types_;
utils::UniqueAllocator<ArrayCount> array_counts_;
};
} // namespace tint::sem

View File

@ -100,63 +100,63 @@ struct TypeTest : public TestHelper {
/* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>(
/* element */ ai,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 32u,
/* size */ 5u * 32u,
/* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_str_f16 = create<Array>(
/* element */ str_f16,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_str_af = create<Array>(
/* element */ str_af,
/* count */ ConstantArrayCount{5u},
/* count */ create<ConstantArrayCount>(5u),
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,

View File

@ -109,11 +109,11 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type*
if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
}
if (auto* override = std::get_if<sem::NamedOverrideArrayCount>(&a->Count())) {
if (auto* override = a->Count()->As<sem::NamedOverrideArrayCount>()) {
auto* count = ctx.Clone(override->variable->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs));
}
if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
if (auto* override = a->Count()->As<sem::UnnamedOverrideArrayCount>()) {
// If the array count is an unnamed (complex) override expression, then its not safe to
// redeclare this type as we'd end up with two types that would not compare equal.
// See crbug.com/tint/1764.

View File

@ -69,8 +69,8 @@ TEST_F(CreateASTTypeForTest, Vector) {
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u,
32u);
return b.create<sem::Array>(b.create<sem::F32>(), b.create<sem::ConstantArrayCount>(2u), 4u,
4u, 32u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
@ -83,8 +83,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u,
32u);
return b.create<sem::Array>(b.create<sem::F32>(), b.create<sem::ConstantArrayCount>(2u), 4u,
4u, 64u, 32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());

View File

@ -91,7 +91,7 @@ class UniqueAllocator {
struct Entry {
/// The pre-calculated hash of the entry
size_t hash;
/// Tge pointer to the unique object
/// The pointer to the unique object
T* ptr;
};
/// Comparator is the hashing and equality function used by the unordered_set