diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 3d10a06585..6a0e70d3ea 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -419,6 +419,7 @@ libtint_source_set("libtint_core_all_src") { "sem/abstract_int.h", "sem/abstract_numeric.h", "sem/array.h", + "sem/array_count.h", "sem/atomic.h", "sem/behavior.h", "sem/binding_point.h", @@ -635,6 +636,8 @@ libtint_source_set("libtint_sem_src") { "sem/abstract_numeric.h", "sem/array.cc", "sem/array.h", + "sem/array_count.cc", + "sem/array_count.h", "sem/atomic.cc", "sem/atomic.h", "sem/behavior.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index cb9c96efff..72cec7b847 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -299,6 +299,8 @@ list(APPEND TINT_LIB_SRCS sem/abstract_numeric.h sem/array.cc sem/array.h + sem/array_count.cc + sem/array_count.h sem/atomic.cc sem/atomic.h sem/behavior.cc diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index d6d1f0059a..4677009e00 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -91,6 +91,7 @@ #include "src/tint/program.h" #include "src/tint/program_id.h" #include "src/tint/sem/array.h" +#include "src/tint/sem/array_count.h" #include "src/tint/sem/bool.h" #include "src/tint/sem/constant.h" #include "src/tint/sem/depth_texture.h" @@ -457,7 +458,8 @@ class ProgramBuilder { /// @returns the node pointer template traits::EnableIf && - !traits::IsTypeOrDerived, + !traits::IsTypeOrDerived && + !traits::IsTypeOrDerived, T>* create(ARGS&&... args) { AssertNotMoved(); @@ -476,17 +478,28 @@ class ProgramBuilder { /// Creates a new sem::Type owned by the ProgramBuilder. /// When the ProgramBuilder is destructed, owned ProgramBuilder and the - /// returned`Type` will also be destructed. + /// returned `Type` will also be destructed. /// Types are unique (de-aliased), and so calling create() for the same `T` /// and arguments will return the same pointer. /// @param args the arguments to pass to the type constructor /// @returns the de-aliased type pointer template traits::EnableIfIsType* create(ARGS&&... args) { - static_assert(std::is_base_of::value, "T does not derive from sem::Type"); AssertNotMoved(); return types_.Get(std::forward(args)...); } + /// Creates a new sem::ArrayCount owned by the ProgramBuilder. + /// When the ProgramBuilder is destructed, owned ProgramBuilder and the + /// returned `ArrayCount` will also be destructed. + /// ArrayCounts are unique (de-aliased), and so calling create() for the same `T` + /// and arguments will return the same pointer. + /// @param args the arguments to pass to the array count constructor + /// @returns the de-aliased array count pointer + template + traits::EnableIfIsType* create(ARGS&&... args) { + AssertNotMoved(); + return types_.GetArrayCount(std::forward(args)...); + } /// Marks this builder as moved, preventing any further use of the builder. void MarkAsMoved(); diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc index 9df980750c..cabc4abe75 100644 --- a/src/tint/resolver/const_eval_construction_test.cc +++ b/src/tint/resolver/const_eval_construction_test.cc @@ -1321,7 +1321,6 @@ TEST_F(ResolverConstEvalTest, Array_i32_Zero) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AnyZero()); @@ -1359,7 +1358,6 @@ TEST_F(ResolverConstEvalTest, Array_f32_Zero) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AnyZero()); @@ -1397,7 +1395,6 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Zero) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AnyZero()); @@ -1449,7 +1446,6 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Zero) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AnyZero()); @@ -1487,7 +1483,6 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AnyZero()); @@ -1525,7 +1520,6 @@ TEST_F(ResolverConstEvalTest, Array_f32_Elements) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AnyZero()); @@ -1564,7 +1558,6 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Elements) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AnyZero()); @@ -1594,7 +1587,6 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Elements) { auto* arr = sem->Type()->As(); ASSERT_NE(arr, nullptr); EXPECT_TRUE(arr->ElemType()->Is()); - EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u}); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AnyZero()); diff --git a/src/tint/resolver/inferred_type_test.cc b/src/tint/resolver/inferred_type_test.cc index 469a5e0818..ddbc8f8d47 100644 --- a/src/tint/resolver/inferred_type_test.cc +++ b/src/tint/resolver/inferred_type_test.cc @@ -135,8 +135,8 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest, ResolverInferredTypeParamTest, testing::V TEST_F(ResolverInferredTypeTest, InferArray_Pass) { auto* type = ty.array(ty.u32(), 10_u); - auto* expected_type = - create(create(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u); + auto* expected_type = create( + create(), create(10u), 4u, 4u * 10u, 4u, 4u); auto* ctor_expr = Construct(type); auto* var = Var("a", ast::AddressSpace::kFunction, ctor_expr); diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc index e58e9807e4..aed49c3566 100644 --- a/src/tint/resolver/intrinsic_table.cc +++ b/src/tint/resolver/intrinsic_table.cc @@ -532,12 +532,13 @@ bool match_array(MatchState&, const sem::Type* ty, const sem::Type*& T) { } const sem::Array* build_array(MatchState& state, const sem::Type* el) { - return state.builder.create(el, - /* count */ sem::RuntimeArrayCount{}, - /* align */ 0u, - /* size */ 0u, - /* stride */ 0u, - /* stride_implicit */ 0u); + return state.builder.create( + el, + /* count */ state.builder.create(), + /* align */ 0u, + /* size */ 0u, + /* stride */ 0u, + /* stride_implicit */ 0u); } bool match_ptr(MatchState&, const sem::Type* ty, Number& S, const sem::Type*& T, Number& A) { diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc index 1608aa2582..03255dfd44 100644 --- a/src/tint/resolver/intrinsic_table_test.cc +++ b/src/tint/resolver/intrinsic_table_test.cc @@ -252,7 +252,8 @@ TEST_F(IntrinsicTableTest, MismatchPointer) { } TEST_F(IntrinsicTableTest, MatchArray) { - auto* arr = create(create(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u); + auto* arr = + create(create(), create(), 4u, 4u, 4u, 4u); auto* arr_ptr = create(arr, ast::AddressSpace::kStorage, ast::Access::kReadWrite); auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr}, sem::EvaluationStage::kConstant, Source{}); @@ -955,7 +956,8 @@ TEST_F(IntrinsicTableTest, MatchTypeConversion) { } TEST_F(IntrinsicTableTest, MismatchTypeConversion) { - auto* arr = create(create(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u); + auto* arr = + create(create(), create(), 4u, 4u, 4u, 4u); auto* f32 = create(); auto result = table->Lookup(InitConvIntrinsic::kVec3, f32, utils::Vector{arr}, sem::EvaluationStage::kConstant, Source{{12, 34}}); diff --git a/src/tint/resolver/is_host_shareable_test.cc b/src/tint/resolver/is_host_shareable_test.cc index ab969fa89e..5e1555b6b5 100644 --- a/src/tint/resolver/is_host_shareable_test.cc +++ b/src/tint/resolver/is_host_shareable_test.cc @@ -106,13 +106,14 @@ TEST_F(ResolverIsHostShareable, Atomic) { } TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) { - auto* arr = - create(create(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u); + auto* arr = create(create(), create(5u), 4u, 20u, + 4u, 4u); EXPECT_TRUE(r()->IsHostShareable(arr)); } TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) { - auto* arr = create(create(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u); + auto* arr = + create(create(), create(), 4u, 4u, 4u, 4u); EXPECT_TRUE(r()->IsHostShareable(arr)); } diff --git a/src/tint/resolver/is_storeable_test.cc b/src/tint/resolver/is_storeable_test.cc index 61cf33b42d..6618199113 100644 --- a/src/tint/resolver/is_storeable_test.cc +++ b/src/tint/resolver/is_storeable_test.cc @@ -89,13 +89,14 @@ TEST_F(ResolverIsStorableTest, Atomic) { } TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) { - auto* arr = - create(create(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u); + auto* arr = create(create(), create(5u), 4u, 20u, + 4u, 4u); EXPECT_TRUE(r()->IsStorable(arr)); } TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) { - auto* arr = create(create(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u); + auto* arr = + create(create(), create(), 4u, 4u, 4u, 4u); EXPECT_TRUE(r()->IsStorable(arr)); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 60f8207798..f909bf0394 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -2143,8 +2143,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { [&](const ast::Array* a) -> sem::Call* { Mark(a); // array element type must be inferred if it was not specified. - sem::ArrayCount el_count = - sem::ConstantArrayCount{static_cast(args.Length())}; + const sem::ArrayCount* el_count = nullptr; const sem::Type* el_ty = nullptr; if (a->type) { el_ty = Type(a->type); @@ -2155,14 +2154,15 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { AddError("cannot construct a runtime-sized array", expr->source); return nullptr; } - if (auto count = ArrayCount(a->count)) { - el_count = count.Get(); - } else { + el_count = ArrayCount(a->count); + if (!el_count) { return nullptr; } // Note: validation later will detect any mismatches between explicit array // size and number of initializer expressions. } else { + el_count = builder_->create( + static_cast(args.Length())); auto arg_tys = utils::Transform(args, [](auto* arg) { return arg->Type()->UnwrapRef(); }); el_ty = sem::Type::Common(arg_tys); @@ -2936,15 +2936,16 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } - sem::ArrayCount el_count = sem::RuntimeArrayCount{}; + const sem::ArrayCount* el_count = nullptr; // Evaluate the constant array count expression. if (auto* count_expr = arr->count) { - if (auto count = ArrayCount(count_expr)) { - el_count = count.Get(); - } else { + el_count = ArrayCount(count_expr); + if (!el_count) { return nullptr; } + } else { + el_count = builder_->create(); } auto* out = Array(arr->type->source, // @@ -2971,11 +2972,11 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return out; } -utils::Result Resolver::ArrayCount(const ast::Expression* count_expr) { +const sem::ArrayCount* Resolver::ArrayCount(const ast::Expression* count_expr) { // Evaluate the constant array count expression. const auto* count_sem = Materialize(Expression(count_expr)); if (!count_sem) { - return utils::Failure; + return nullptr; } if (count_sem->Stage() == sem::EvaluationStage::kOverride) { @@ -2983,34 +2984,34 @@ utils::Result Resolver::ArrayCount(const ast::Expression* count // Is the count a named 'override'? if (auto* user = count_sem->UnwrapMaterialize()->As()) { if (auto* global = user->Variable()->As()) { - return sem::ArrayCount{sem::NamedOverrideArrayCount{global}}; + return builder_->create(global); } } - return sem::ArrayCount{sem::UnnamedOverrideArrayCount{count_sem}}; + return builder_->create(count_sem); } auto* count_val = count_sem->ConstantValue(); if (!count_val) { AddError("array count must evaluate to a constant integer expression or override variable", count_expr->source); - return utils::Failure; + return nullptr; } if (auto* ty = count_val->Type(); !ty->is_integer_scalar()) { AddError("array count must evaluate to a constant integer expression, but is type '" + builder_->FriendlyName(ty) + "'", count_expr->source); - return utils::Failure; + return nullptr; } int64_t count = count_val->As(); if (count < 1) { AddError("array count (" + std::to_string(count) + ") must be greater than 0", count_expr->source); - return utils::Failure; + return nullptr; } - return sem::ArrayCount{sem::ConstantArrayCount{static_cast(count)}}; + return builder_->create(static_cast(count)); } bool Resolver::ArrayAttributes(utils::VectorRef attributes, @@ -3046,7 +3047,7 @@ bool Resolver::ArrayAttributes(utils::VectorRef attribute sem::Array* Resolver::Array(const Source& el_source, const Source& count_source, const sem::Type* el_ty, - sem::ArrayCount el_count, + const sem::ArrayCount* el_count, uint32_t explicit_stride) { uint32_t el_align = el_ty->Align(); uint32_t el_size = el_ty->Size(); @@ -3054,7 +3055,7 @@ sem::Array* Resolver::Array(const Source& el_source, uint64_t stride = explicit_stride ? explicit_stride : implicit_stride; uint64_t size = 0; - if (auto const_count = std::get_if(&el_count)) { + if (auto const_count = el_count->As()) { size = const_count->value * stride; if (size > std::numeric_limits::max()) { std::stringstream msg; @@ -3063,7 +3064,7 @@ sem::Array* Resolver::Array(const Source& el_source, AddError(msg.str(), count_source); return nullptr; } - } else if (std::holds_alternative(el_count)) { + } else if (el_count->Is()) { size = stride; } auto* out = builder_->create(el_ty, el_count, el_align, static_cast(size), diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 56ae91214c..0deef5ffa4 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -273,7 +273,7 @@ class Resolver { /// Resolves and validates the expression used as the count parameter of an array. /// @param count_expr the expression used as the second template parameter to an array<>. /// @returns the number of elements in the array. - utils::Result ArrayCount(const ast::Expression* count_expr); + const sem::ArrayCount* ArrayCount(const ast::Expression* count_expr); /// Resolves and validates the attributes on an array. /// @param attributes the attributes on the array type. @@ -296,7 +296,7 @@ class Resolver { sem::Array* Array(const Source& el_source, const Source& count_source, const sem::Type* el_ty, - sem::ArrayCount el_count, + const sem::ArrayCount* el_count, uint32_t explicit_stride); /// Builds and returns the semantic information for the alias `alias`. diff --git a/src/tint/resolver/resolver_test.cc b/src/tint/resolver/resolver_test.cc index 13b1da1e0a..2a882494b5 100644 --- a/src/tint/resolver/resolver_test.cc +++ b/src/tint/resolver/resolver_test.cc @@ -440,7 +440,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedLiteral) { auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); - EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); + EXPECT_EQ(ary->Count(), create(10u)); } TEST_F(ResolverTest, ArraySize_SignedLiteral) { @@ -453,7 +453,7 @@ TEST_F(ResolverTest, ArraySize_SignedLiteral) { auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); - EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); + EXPECT_EQ(ary->Count(), create(10u)); } TEST_F(ResolverTest, ArraySize_UnsignedConst) { @@ -468,7 +468,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedConst) { auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); - EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); + EXPECT_EQ(ary->Count(), create(10u)); } TEST_F(ResolverTest, ArraySize_SignedConst) { @@ -483,7 +483,7 @@ TEST_F(ResolverTest, ArraySize_SignedConst) { auto* ref = TypeOf(a)->As(); ASSERT_NE(ref, nullptr); auto* ary = ref->StoreType()->As(); - EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u}); + EXPECT_EQ(ary->Count(), create(10u)); } TEST_F(ResolverTest, ArraySize_NamedOverride) { @@ -500,7 +500,7 @@ TEST_F(ResolverTest, ArraySize_NamedOverride) { auto* ary = ref->StoreType()->As(); auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); - EXPECT_EQ(ary->Count(), sem::NamedOverrideArrayCount{sem_override}); + EXPECT_EQ(ary->Count(), create(sem_override)); } TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) { @@ -525,8 +525,8 @@ TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) { auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); - EXPECT_EQ(ary_a->Count(), sem::NamedOverrideArrayCount{sem_override}); - EXPECT_EQ(ary_b->Count(), sem::NamedOverrideArrayCount{sem_override}); + EXPECT_EQ(ary_a->Count(), create(sem_override)); + EXPECT_EQ(ary_b->Count(), create(sem_override)); EXPECT_EQ(ary_a, ary_b); } @@ -545,7 +545,7 @@ TEST_F(ResolverTest, ArraySize_UnnamedOverride) { auto* ary = ref->StoreType()->As(); auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); - EXPECT_EQ(ary->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(cnt)}); + EXPECT_EQ(ary->Count(), create(Sem().Get(cnt))); } TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) { @@ -572,8 +572,8 @@ TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) { auto* sem_override = Sem().Get(override); ASSERT_NE(sem_override, nullptr); - EXPECT_EQ(ary_a->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(a_cnt)}); - EXPECT_EQ(ary_b->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(b_cnt)}); + EXPECT_EQ(ary_a->Count(), create(Sem().Get(a_cnt))); + EXPECT_EQ(ary_b->Count(), create(Sem().Get(b_cnt))); EXPECT_NE(ary_a, ary_b); } diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index fe33c265d0..f2567c81cd 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -659,9 +659,11 @@ struct DataType> { /// @return the semantic array type static inline const sem::Type* Sem(ProgramBuilder& b) { auto* el = DataType::Sem(b); - sem::ArrayCount count = sem::ConstantArrayCount{N}; + const sem::ArrayCount* count = nullptr; if (N == 0) { - count = sem::RuntimeArrayCount{}; + count = b.create(); + } else { + count = b.create(N); } return b.create( /* element */ el, diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 83c8c7831a..5a7f9d88e6 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1778,7 +1778,12 @@ bool Validator::ArrayInitializer(const ast::CallExpression* ctor, return false; } - const auto count = std::get(array_type->Count()).value; + if (!array_type->IsConstantSized()) { + TINT_ICE(Resolver, diagnostics_) << "Invalid ArrayCount found"; + return false; + } + + const auto count = array_type->Count()->As()->value; if (!values.IsEmpty() && (values.Length() != count)) { std::string fm = values.Length() < count ? "few" : "many"; AddError("array initializer has too " + fm + " elements: expected " + diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc index 9fa064ac77..cd079ce171 100644 --- a/src/tint/resolver/validator_is_storeable_test.cc +++ b/src/tint/resolver/validator_is_storeable_test.cc @@ -89,13 +89,14 @@ TEST_F(ValidatorIsStorableTest, Atomic) { } TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) { - auto* arr = - create(create(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u); + auto* arr = create(create(), create(5u), 4u, 20u, + 4u, 4u); EXPECT_TRUE(v()->IsStorable(arr)); } TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) { - auto* arr = create(create(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u); + auto* arr = + create(create(), create(), 4u, 4u, 4u, 4u); EXPECT_TRUE(v()->IsStorable(arr)); } diff --git a/src/tint/sem/array.cc b/src/tint/sem/array.cc index d61d430784..207a6278a5 100644 --- a/src/tint/sem/array.cc +++ b/src/tint/sem/array.cc @@ -28,10 +28,10 @@ namespace tint::sem { namespace { -TypeFlags FlagsFrom(const Type* element, ArrayCount count) { +TypeFlags FlagsFrom(const Type* element, const ArrayCount* count) { TypeFlags flags; // Only constant-expression sized arrays are constructible - if (std::holds_alternative(count)) { + if (count->Is()) { if (element->IsConstructible()) { flags.Add(TypeFlag::kConstructable); } @@ -39,9 +39,7 @@ TypeFlags FlagsFrom(const Type* element, ArrayCount count) { flags.Add(TypeFlag::kCreationFixedFootprint); } } - if (std::holds_alternative(count) || - std::holds_alternative(count) || - std::holds_alternative(count)) { + if (count->IsAnyOf()) { if (element->HasFixedFootprint()) { flags.Add(TypeFlag::kFixedFootprint); } @@ -56,7 +54,7 @@ const char* const Array::kErrExpectedConstantCount = "Was the SubstituteOverride transform run?"; Array::Array(const Type* element, - ArrayCount count, + const ArrayCount* count, uint32_t align, uint32_t size, uint32_t stride, @@ -91,11 +89,11 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const { out << "@stride(" << stride_ << ") "; } out << "array<" << element_->FriendlyName(symbols); - if (auto* const_count = std::get_if(&count_)) { + if (auto* const_count = count_->As()) { out << ", " << const_count->value; - } else if (auto* named_override_count = std::get_if(&count_)) { + } else if (auto* named_override_count = count_->As()) { out << ", " << symbols.NameFor(named_override_count->variable->Declaration()->symbol); - } else if (std::holds_alternative(count_)) { + } else if (count_->Is()) { out << ", [unnamed override-expression]"; } out << ">"; diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h index 4d1ed7dce6..c068948499 100644 --- a/src/tint/sem/array.h +++ b/src/tint/sem/array.h @@ -20,6 +20,7 @@ #include #include +#include "src/tint/sem/array_count.h" #include "src/tint/sem/node.h" #include "src/tint/sem/type.h" #include "src/tint/utils/compiler_macros.h" @@ -33,115 +34,6 @@ class GlobalVariable; namespace tint::sem { -/// The variant of an ArrayCount when the array is a const-expression. -/// Example: -/// ``` -/// const N = 123; -/// type arr = array -/// ``` -struct ConstantArrayCount { - /// The array count constant-expression value. - uint32_t value; -}; - -/// The variant of an ArrayCount when the count is a named override variable. -/// Example: -/// ``` -/// override N : i32; -/// type arr = array -/// ``` -struct NamedOverrideArrayCount { - /// The `override` variable. - const GlobalVariable* variable; -}; - -/// The variant of an ArrayCount when the count is an unnamed override variable. -/// Example: -/// ``` -/// override N : i32; -/// type arr = array -/// ``` -struct UnnamedOverrideArrayCount { - /// The unnamed override expression. - /// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST - /// expressions will not result in the same `expr` pointer. This property is important to ensure - /// that two array declarations with equivalent AST expressions do not compare equal. - /// For example, consider: - /// ``` - /// override size : u32; - /// var a : array; - /// var b : array; - /// ``` - // The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and - // `b` must not compare equal. - const Expression* expr; -}; - -/// The variant of an ArrayCount when the array is is runtime-sized. -/// Example: -/// ``` -/// type arr = array -/// ``` -struct RuntimeArrayCount {}; - -/// An array count is either a constant-expression value, a named override identifier, an unnamed -/// override identifier, or runtime-sized. -using ArrayCount = std::variant; - -/// Equality operator -/// @param a the LHS ConstantArrayCount -/// @param b the RHS ConstantArrayCount -/// @returns true if @p a is equal to @p b -inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) { - return a.value == b.value; -} - -/// Equality operator -/// @param a the LHS OverrideArrayCount -/// @param b the RHS OverrideArrayCount -/// @returns true if @p a is equal to @p b -inline bool operator==(const NamedOverrideArrayCount& a, const NamedOverrideArrayCount& b) { - return a.variable == b.variable; -} - -/// Equality operator -/// @param a the LHS OverrideArrayCount -/// @param b the RHS OverrideArrayCount -/// @returns true if @p a is equal to @p b -inline bool operator==(const UnnamedOverrideArrayCount& a, const UnnamedOverrideArrayCount& b) { - return a.expr == b.expr; -} - -/// Equality operator -/// @returns true -inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) { - return true; -} - -/// Equality operator -/// @param a the LHS ArrayCount -/// @param b the RHS count -/// @returns true if @p a is equal to @p b -template || std::is_same_v || - std::is_same_v || std::is_same_v>> -inline bool operator==(const ArrayCount& a, const T& b) { - TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); - return std::visit( - [&](auto count) { - if constexpr (std::is_same_v, T>) { - return count == b; - } - return false; - }, - a); - TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); -} - /// Array holds the semantic information for Array nodes. class Array final : public Castable { public: @@ -161,7 +53,7 @@ class Array final : public Castable { /// of the array to the start of the next element, if there was no `@stride` /// attribute applied. Array(Type const* element, - ArrayCount count, + const ArrayCount* count, uint32_t align, uint32_t size, uint32_t stride, @@ -178,11 +70,11 @@ class Array final : public Castable { Type const* ElemType() const { return element_; } /// @returns the number of elements in the array. - const ArrayCount& Count() const { return count_; } + const ArrayCount* Count() const { return count_; } /// @returns the array count if the count is a const-expression, otherwise returns nullopt. inline std::optional ConstantCount() const { - if (auto* count = std::get_if(&count_)) { + if (auto* count = count_->As()) { return count->value; } return std::nullopt; @@ -212,23 +104,19 @@ class Array final : public Castable { bool IsStrideImplicit() const { return stride_ == implicit_stride_; } /// @returns true if this array is sized using an const-expression - bool IsConstantSized() const { return std::holds_alternative(count_); } + bool IsConstantSized() const { return count_->Is(); } /// @returns true if this array is sized using a named override variable - bool IsNamedOverrideSized() const { - return std::holds_alternative(count_); - } + bool IsNamedOverrideSized() const { return count_->Is(); } /// @returns true if this array is sized using an unnamed override variable - bool IsUnnamedOverrideSized() const { - return std::holds_alternative(count_); - } + bool IsUnnamedOverrideSized() const { return count_->Is(); } /// @returns true if this array is sized using a named or unnamed override variable bool IsOverrideSized() const { return IsNamedOverrideSized() || IsUnnamedOverrideSized(); } /// @returns true if this array is runtime sized - bool IsRuntimeSized() const { return std::holds_alternative(count_); } + bool IsRuntimeSized() const { return count_->Is(); } /// @param symbols the program's symbol table /// @returns the name for this type that closely resembles how it would be @@ -237,7 +125,7 @@ class Array final : public Castable { private: Type const* const element_; - const ArrayCount count_; + const ArrayCount* count_; const uint32_t align_; const uint32_t size_; const uint32_t stride_; @@ -246,49 +134,4 @@ class Array final : public Castable { } // namespace tint::sem -namespace std { - -/// Custom std::hash specialization for tint::sem::ConstantArrayCount. -template <> -class hash { - public: - /// @param count the count to hash - /// @return the hash value - inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const { - return std::hash()(count.value); - } -}; - -/// Custom std::hash specialization for tint::sem::NamedOverrideArrayCount. -template <> -class hash { - public: - /// @param count the count to hash - /// @return the hash value - inline std::size_t operator()(const tint::sem::NamedOverrideArrayCount& count) const { - return std::hash()(count.variable); - } -}; - -/// Custom std::hash specialization for tint::sem::UnnamedOverrideArrayCount. -template <> -class hash { - public: - /// @param count the count to hash - /// @return the hash value - inline std::size_t operator()(const tint::sem::UnnamedOverrideArrayCount& count) const { - return std::hash()(count.expr); - } -}; - -/// Custom std::hash specialization for tint::sem::RuntimeArrayCount. -template <> -class hash { - public: - /// @return the hash value - inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; } -}; - -} // namespace std - #endif // SRC_TINT_SEM_ARRAY_H_ diff --git a/src/tint/sem/array_count.cc b/src/tint/sem/array_count.cc new file mode 100644 index 0000000000..fa1663927c --- /dev/null +++ b/src/tint/sem/array_count.cc @@ -0,0 +1,82 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/sem/array_count.h" + +TINT_INSTANTIATE_TYPEINFO(tint::sem::ArrayCount); +TINT_INSTANTIATE_TYPEINFO(tint::sem::ConstantArrayCount); +TINT_INSTANTIATE_TYPEINFO(tint::sem::RuntimeArrayCount); +TINT_INSTANTIATE_TYPEINFO(tint::sem::NamedOverrideArrayCount); +TINT_INSTANTIATE_TYPEINFO(tint::sem::UnnamedOverrideArrayCount); + +namespace tint::sem { + +ArrayCount::ArrayCount() : Base() {} +ArrayCount::~ArrayCount() = default; + +ConstantArrayCount::ConstantArrayCount(uint32_t val) : Base(), value(val) {} +ConstantArrayCount::~ConstantArrayCount() = default; + +size_t ConstantArrayCount::Hash() const { + return static_cast(TypeInfo::Of().full_hashcode); +} + +bool ConstantArrayCount::Equals(const ArrayCount& other) const { + if (auto* v = other.As()) { + return value == v->value; + } + return false; +} + +RuntimeArrayCount::RuntimeArrayCount() : Base() {} +RuntimeArrayCount::~RuntimeArrayCount() = default; + +size_t RuntimeArrayCount::Hash() const { + return static_cast(TypeInfo::Of().full_hashcode); +} + +bool RuntimeArrayCount::Equals(const ArrayCount& other) const { + return other.Is(); +} + +NamedOverrideArrayCount::NamedOverrideArrayCount(const GlobalVariable* var) + : Base(), variable(var) {} +NamedOverrideArrayCount::~NamedOverrideArrayCount() = default; + +size_t NamedOverrideArrayCount::Hash() const { + return static_cast(TypeInfo::Of().full_hashcode); +} + +bool NamedOverrideArrayCount::Equals(const ArrayCount& other) const { + if (auto* v = other.As()) { + return variable == v->variable; + } + return false; +} + +UnnamedOverrideArrayCount::UnnamedOverrideArrayCount(const Expression* e) : Base(), expr(e) {} +UnnamedOverrideArrayCount::~UnnamedOverrideArrayCount() = default; + +size_t UnnamedOverrideArrayCount::Hash() const { + return static_cast(TypeInfo::Of().full_hashcode); +} + +bool UnnamedOverrideArrayCount::Equals(const ArrayCount& other) const { + if (auto* v = other.As()) { + return expr == v->expr; + } + return false; +} + +} // namespace tint::sem diff --git a/src/tint/sem/array_count.h b/src/tint/sem/array_count.h new file mode 100644 index 0000000000..eb1a0016b6 --- /dev/null +++ b/src/tint/sem/array_count.h @@ -0,0 +1,170 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_SEM_ARRAY_COUNT_H_ +#define SRC_TINT_SEM_ARRAY_COUNT_H_ + +#include +#include + +#include "src/tint/sem/expression.h" +#include "src/tint/sem/node.h" +#include "src/tint/sem/variable.h" + +namespace tint::sem { + +/// An array count +class ArrayCount : public Castable { + public: + ~ArrayCount() override; + + /// @returns a hash of the array count. + virtual size_t Hash() const = 0; + + /// @param t other array count + /// @returns true if this array count is equal to the given array count + virtual bool Equals(const ArrayCount& t) const = 0; + + protected: + ArrayCount(); +}; + +/// The variant of an ArrayCount when the array is a const-expression. +/// Example: +/// ``` +/// const N = 123; +/// type arr = array +/// ``` +class ConstantArrayCount final : public Castable { + public: + /// Constructor + /// @param val the constant-expression value + explicit ConstantArrayCount(uint32_t val); + ~ConstantArrayCount() override; + + /// @returns a hash of the array count. + size_t Hash() const override; + + /// @param t other array count + /// @returns true if this array count is equal to the given array count + bool Equals(const ArrayCount& t) const override; + + /// The array count constant-expression value. + uint32_t value; +}; + +/// The variant of an ArrayCount when the array is is runtime-sized. +/// Example: +/// ``` +/// type arr = array +/// ``` +class RuntimeArrayCount final : public Castable { + public: + /// Constructor + RuntimeArrayCount(); + ~RuntimeArrayCount() override; + + /// @returns a hash of the array count. + size_t Hash() const override; + + /// @param t other array count + /// @returns true if this array count is equal to the given array count + bool Equals(const ArrayCount& t) const override; +}; + +/// The variant of an ArrayCount when the count is a named override variable. +/// Example: +/// ``` +/// override N : i32; +/// type arr = array +/// ``` +class NamedOverrideArrayCount final : public Castable { + public: + /// Constructor + /// @param var the `override` variable + explicit NamedOverrideArrayCount(const GlobalVariable* var); + ~NamedOverrideArrayCount() override; + + /// @returns a hash of the array count. + size_t Hash() const override; + + /// @param t other array count + /// @returns true if this array count is equal to the given array count + bool Equals(const ArrayCount& t) const override; + + /// The `override` variable. + const GlobalVariable* variable; +}; + +/// The variant of an ArrayCount when the count is an unnamed override variable. +/// Example: +/// ``` +/// override N : i32; +/// type arr = array +/// ``` +class UnnamedOverrideArrayCount final : public Castable { + public: + /// Constructor + /// @param e the override expression + explicit UnnamedOverrideArrayCount(const Expression* e); + ~UnnamedOverrideArrayCount() override; + + /// @returns a hash of the array count. + size_t Hash() const override; + + /// @param t other array count + /// @returns true if this array count is equal to the given array count + bool Equals(const ArrayCount& t) const override; + + /// The unnamed override expression. + /// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST + /// expressions will not result in the same `expr` pointer. This property is important to ensure + /// that two array declarations with equivalent AST expressions do not compare equal. + /// For example, consider: + /// ``` + /// override size : u32; + /// var a : array; + /// var b : array; + /// ``` + // The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and + // `b` must not compare equal. + const Expression* expr; +}; + +} // namespace tint::sem + +namespace std { + +/// std::hash specialization for tint::sem::ArrayCount +template <> +struct hash { + /// @param a the array count to obtain a hash from + /// @returns the hash of the array count + size_t operator()(const tint::sem::ArrayCount& a) const { return a.Hash(); } +}; + +/// std::equal_to specialization for tint::sem::ArrayCount +template <> +struct equal_to { + /// @param a the first array count to compare + /// @param b the second array count to compare + /// @returns true if the two array counts are equal + bool operator()(const tint::sem::ArrayCount& a, const tint::sem::ArrayCount& b) const { + return a.Equals(b); + } +}; + +} // namespace std + +#endif // SRC_TINT_SEM_ARRAY_COUNT_H_ diff --git a/src/tint/sem/array_test.cc b/src/tint/sem/array_test.cc index a51b492239..44073df3d6 100644 --- a/src/tint/sem/array_test.cc +++ b/src/tint/sem/array_test.cc @@ -21,16 +21,16 @@ namespace { using ArrayTest = TestHelper; TEST_F(ArrayTest, CreateSizedArray) { - auto* a = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* b = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* c = create(create(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u); - auto* d = create(create(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u); - auto* e = create(create(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u); - auto* f = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u); - auto* g = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u); + auto* a = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* b = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* c = create(create(), create(3u), 4u, 8u, 32u, 16u); + auto* d = create(create(), create(2u), 5u, 8u, 32u, 16u); + auto* e = create(create(), create(2u), 4u, 9u, 32u, 16u); + auto* f = create(create(), create(2u), 4u, 8u, 33u, 16u); + auto* g = create(create(), create(2u), 4u, 8u, 33u, 17u); EXPECT_EQ(a->ElemType(), create()); - EXPECT_EQ(a->Count(), ConstantArrayCount{2u}); + EXPECT_EQ(a->Count(), create(2u)); EXPECT_EQ(a->Align(), 4u); EXPECT_EQ(a->Size(), 8u); EXPECT_EQ(a->Stride(), 32u); @@ -47,15 +47,15 @@ TEST_F(ArrayTest, CreateSizedArray) { } TEST_F(ArrayTest, CreateRuntimeArray) { - auto* a = create(create(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u); - auto* b = create(create(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u); - auto* c = create(create(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u); - auto* d = create(create(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u); - auto* e = create(create(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u); - auto* f = create(create(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u); + auto* a = create(create(), create(), 4u, 8u, 32u, 32u); + auto* b = create(create(), create(), 4u, 8u, 32u, 32u); + auto* c = create(create(), create(), 5u, 8u, 32u, 32u); + auto* d = create(create(), create(), 4u, 9u, 32u, 32u); + auto* e = create(create(), create(), 4u, 8u, 33u, 32u); + auto* f = create(create(), create(), 4u, 8u, 33u, 17u); EXPECT_EQ(a->ElemType(), create()); - EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{}); + EXPECT_EQ(a->Count(), create()); EXPECT_EQ(a->Align(), 4u); EXPECT_EQ(a->Size(), 8u); EXPECT_EQ(a->Stride(), 32u); @@ -71,13 +71,13 @@ TEST_F(ArrayTest, CreateRuntimeArray) { } TEST_F(ArrayTest, Hash) { - auto* a = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* b = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* c = create(create(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u); - auto* d = create(create(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u); - auto* e = create(create(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u); - auto* f = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u); - auto* g = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u); + auto* a = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* b = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* c = create(create(), create(3u), 4u, 8u, 32u, 16u); + auto* d = create(create(), create(2u), 5u, 8u, 32u, 16u); + auto* e = create(create(), create(2u), 4u, 9u, 32u, 16u); + auto* f = create(create(), create(2u), 4u, 8u, 33u, 16u); + auto* g = create(create(), create(2u), 4u, 8u, 33u, 17u); EXPECT_EQ(a->Hash(), b->Hash()); EXPECT_NE(a->Hash(), c->Hash()); @@ -88,13 +88,13 @@ TEST_F(ArrayTest, Hash) { } TEST_F(ArrayTest, Equals) { - auto* a = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* b = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); - auto* c = create(create(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u); - auto* d = create(create(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u); - auto* e = create(create(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u); - auto* f = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u); - auto* g = create(create(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u); + auto* a = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* b = create(create(), create(2u), 4u, 8u, 32u, 16u); + auto* c = create(create(), create(3u), 4u, 8u, 32u, 16u); + auto* d = create(create(), create(2u), 5u, 8u, 32u, 16u); + auto* e = create(create(), create(2u), 4u, 9u, 32u, 16u); + auto* f = create(create(), create(2u), 4u, 8u, 33u, 16u); + auto* g = create(create(), create(2u), 4u, 8u, 33u, 17u); EXPECT_TRUE(a->Equals(*b)); EXPECT_FALSE(a->Equals(*c)); @@ -106,32 +106,34 @@ TEST_F(ArrayTest, Equals) { } TEST_F(ArrayTest, FriendlyNameRuntimeSized) { - auto* arr = create(create(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u); + auto* arr = create(create(), create(), 0u, 4u, 4u, 4u); EXPECT_EQ(arr->FriendlyName(Symbols()), "array"); } TEST_F(ArrayTest, FriendlyNameStaticSized) { - auto* arr = create(create(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u); + auto* arr = create(create(), create(5u), 4u, 20u, 4u, 4u); EXPECT_EQ(arr->FriendlyName(Symbols()), "array"); } TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) { - auto* arr = create(create(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u); + auto* arr = create(create(), create(), 0u, 4u, 8u, 4u); EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array"); } TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) { - auto* arr = create(create(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u); + auto* arr = create(create(), create(5u), 4u, 20u, 8u, 4u); EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array"); } TEST_F(ArrayTest, IsConstructable) { - auto* fixed_sized = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); + auto* fixed_sized = + create(create(), create(2u), 4u, 8u, 32u, 16u); auto* named_override_sized = - create(create(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); auto* unnamed_override_sized = - create(create(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); - auto* runtime_sized = create(create(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); + auto* runtime_sized = + create(create(), create(), 4u, 8u, 32u, 16u); EXPECT_TRUE(fixed_sized->IsConstructible()); EXPECT_FALSE(named_override_sized->IsConstructible()); @@ -140,12 +142,14 @@ TEST_F(ArrayTest, IsConstructable) { } TEST_F(ArrayTest, HasCreationFixedFootprint) { - auto* fixed_sized = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); + auto* fixed_sized = + create(create(), create(2u), 4u, 8u, 32u, 16u); auto* named_override_sized = - create(create(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); auto* unnamed_override_sized = - create(create(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); - auto* runtime_sized = create(create(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); + auto* runtime_sized = + create(create(), create(), 4u, 8u, 32u, 16u); EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint()); EXPECT_FALSE(named_override_sized->HasCreationFixedFootprint()); @@ -154,12 +158,14 @@ TEST_F(ArrayTest, HasCreationFixedFootprint) { } TEST_F(ArrayTest, HasFixedFootprint) { - auto* fixed_sized = create(create(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u); + auto* fixed_sized = + create(create(), create(2u), 4u, 8u, 32u, 16u); auto* named_override_sized = - create(create(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); auto* unnamed_override_sized = - create(create(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u); - auto* runtime_sized = create(create(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u); + create(create(), create(nullptr), 4u, 8u, 32u, 16u); + auto* runtime_sized = + create(create(), create(), 4u, 8u, 32u, 16u); EXPECT_TRUE(fixed_sized->HasFixedFootprint()); EXPECT_TRUE(named_override_sized->HasFixedFootprint()); diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc index f9db53b6c5..3d25e7eb3c 100644 --- a/src/tint/sem/type.cc +++ b/src/tint/sem/type.cc @@ -273,7 +273,7 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) { }, [&](const Array* a) { if (count) { - if (auto* const_count = std::get_if(&a->Count())) { + if (auto* const_count = a->Count()->As()) { *count = const_count->value; } } diff --git a/src/tint/sem/type_manager.h b/src/tint/sem/type_manager.h index 72f843ad5b..33c42b3b32 100644 --- a/src/tint/sem/type_manager.h +++ b/src/tint/sem/type_manager.h @@ -19,6 +19,7 @@ #include #include +#include "src/tint/sem/array_count.h" #include "src/tint/sem/type.h" #include "src/tint/utils/unique_allocator.h" @@ -56,6 +57,7 @@ class TypeManager final { static TypeManager Wrap(const TypeManager& inner) { TypeManager out; out.types_.Wrap(inner.types_); + out.array_counts_.Wrap(inner.array_counts_); return out; } @@ -80,6 +82,17 @@ class TypeManager final { return types_.Find(std::forward(args)...); } + /// @param args the arguments used to construct the object. + /// @return a pointer to an instance of `T` with the provided arguments. + /// If an existing instance of `T` has been constructed, then the same + /// pointer is returned. + template >, + typename... ARGS> + TYPE* GetArrayCount(ARGS&&... args) { + return array_counts_.Get(std::forward(args)...); + } + /// @returns an iterator to the beginning of the types TypeIterator begin() const { return types_.begin(); } /// @returns an iterator to the end of the types @@ -87,6 +100,7 @@ class TypeManager final { private: utils::UniqueAllocator types_; + utils::UniqueAllocator array_counts_; }; } // namespace tint::sem diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc index db7616b383..91483221b2 100644 --- a/src/tint/sem/type_test.cc +++ b/src/tint/sem/type_test.cc @@ -100,63 +100,63 @@ struct TypeTest : public TestHelper { /* size_no_padding*/ 4u); const sem::Array* arr_i32 = create( /* element */ i32, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 4u, /* size */ 5u * 4u, /* stride */ 5u * 4u, /* implicit_stride */ 5u * 4u); const sem::Array* arr_ai = create( /* element */ ai, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 4u, /* size */ 5u * 4u, /* stride */ 5u * 4u, /* implicit_stride */ 5u * 4u); const sem::Array* arr_vec3_i32 = create( /* element */ vec3_i32, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 16u, /* size */ 5u * 16u, /* stride */ 5u * 16u, /* implicit_stride */ 5u * 16u); const sem::Array* arr_vec3_ai = create( /* element */ vec3_ai, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 16u, /* size */ 5u * 16u, /* stride */ 5u * 16u, /* implicit_stride */ 5u * 16u); const sem::Array* arr_mat4x3_f16 = create( /* element */ mat4x3_f16, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 32u, /* size */ 5u * 32u, /* stride */ 5u * 32u, /* implicit_stride */ 5u * 32u); const sem::Array* arr_mat4x3_f32 = create( /* element */ mat4x3_f32, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 64u, /* size */ 5u * 64u, /* stride */ 5u * 64u, /* implicit_stride */ 5u * 64u); const sem::Array* arr_mat4x3_af = create( /* element */ mat4x3_af, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 64u, /* size */ 5u * 64u, /* stride */ 5u * 64u, /* implicit_stride */ 5u * 64u); const sem::Array* arr_str_f16 = create( /* element */ str_f16, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 4u, /* size */ 5u * 4u, /* stride */ 5u * 4u, /* implicit_stride */ 5u * 4u); const sem::Array* arr_str_af = create( /* element */ str_af, - /* count */ ConstantArrayCount{5u}, + /* count */ create(5u), /* align */ 4u, /* size */ 5u * 4u, /* stride */ 5u * 4u, diff --git a/src/tint/transform/transform.cc b/src/tint/transform/transform.cc index 5c8357c3df..d2eea428d4 100644 --- a/src/tint/transform/transform.cc +++ b/src/tint/transform/transform.cc @@ -109,11 +109,11 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type* if (a->IsRuntimeSized()) { return ctx.dst->ty.array(el, nullptr, std::move(attrs)); } - if (auto* override = std::get_if(&a->Count())) { + if (auto* override = a->Count()->As()) { auto* count = ctx.Clone(override->variable->Declaration()); return ctx.dst->ty.array(el, count, std::move(attrs)); } - if (auto* override = std::get_if(&a->Count())) { + if (auto* override = a->Count()->As()) { // If the array count is an unnamed (complex) override expression, then its not safe to // redeclare this type as we'd end up with two types that would not compare equal. // See crbug.com/tint/1764. diff --git a/src/tint/transform/transform_test.cc b/src/tint/transform/transform_test.cc index 2a39094294..29a21e1fac 100644 --- a/src/tint/transform/transform_test.cc +++ b/src/tint/transform/transform_test.cc @@ -69,8 +69,8 @@ TEST_F(CreateASTTypeForTest, Vector) { TEST_F(CreateASTTypeForTest, ArrayImplicitStride) { auto* arr = create([](ProgramBuilder& b) { - return b.create(b.create(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u, - 32u); + return b.create(b.create(), b.create(2u), 4u, + 4u, 32u, 32u); }); ASSERT_TRUE(arr->Is()); ASSERT_TRUE(arr->As()->type->Is()); @@ -83,8 +83,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) { TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) { auto* arr = create([](ProgramBuilder& b) { - return b.create(b.create(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u, - 32u); + return b.create(b.create(), b.create(2u), 4u, + 4u, 64u, 32u); }); ASSERT_TRUE(arr->Is()); ASSERT_TRUE(arr->As()->type->Is()); diff --git a/src/tint/utils/unique_allocator.h b/src/tint/utils/unique_allocator.h index 25681cdfff..7ba59097c3 100644 --- a/src/tint/utils/unique_allocator.h +++ b/src/tint/utils/unique_allocator.h @@ -91,7 +91,7 @@ class UniqueAllocator { struct Entry { /// The pre-calculated hash of the entry size_t hash; - /// Tge pointer to the unique object + /// The pointer to the unique object T* ptr; }; /// Comparator is the hashing and equality function used by the unordered_set