Fix overrides in array size.

This CL fixes the usage of overrides in array sizes. Currently
the usage will generate a validation error as we check that the
array size is const.

Bug: tint:1660
Change-Id: Ibf440905c30a73b581d55b0c071b8621b61605e6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101900
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
dan sinclair 2022-09-22 22:28:21 +00:00 committed by Dawn LUCI CQ
parent 534a198f88
commit 78f8067fd5
42 changed files with 839 additions and 186 deletions

View File

@ -394,8 +394,10 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return builder.create<Splat>(type, zero_el, m->columns()); return builder.create<Splat>(type, zero_el, m->columns());
}, },
[&](const sem::Array* a) -> const ImplConstant* { [&](const sem::Array* a) -> const ImplConstant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) { if (auto n = a->ConstantCount()) {
return builder.create<Splat>(type, zero_el, a->Count()); if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, n.value());
}
} }
return nullptr; return nullptr;
}, },
@ -451,12 +453,16 @@ bool Equal(const sem::Constant* a, const sem::Constant* b) {
return true; return true;
}, },
[&](const sem::Array* arr) { [&](const sem::Array* arr) {
for (size_t i = 0; i < arr->Count(); i++) { if (auto count = arr->ConstantCount()) {
if (!Equal(a->Index(i), b->Index(i))) { for (size_t i = 0; i < count; i++) {
return false; if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
} }
return true;
} }
return true;
return false;
}, },
[&](Default) { return a->Value() == b->Value(); }); [&](Default) { return a->Value() == b->Value(); });
} }

View File

@ -1700,7 +1700,7 @@ TEST_F(ResolverConstEvalTest, Array_i32_Zero) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>()); EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), 4u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero()); EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1738,7 +1738,7 @@ TEST_F(ResolverConstEvalTest, Array_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>()); EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), 4u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero()); EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1776,7 +1776,7 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>()); EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), 2u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero()); EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1828,7 +1828,7 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>()); EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), 2u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual()); EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero()); EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1866,7 +1866,7 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>()); EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), 4u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero()); EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1904,7 +1904,7 @@ TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>()); EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), 4u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero()); EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1943,7 +1943,7 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>()); EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), 2u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero()); EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1973,7 +1973,7 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>(); auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr); ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>()); EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), 2u); EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual()); EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero()); EXPECT_FALSE(sem->ConstantValue()->AnyZero());

View File

@ -877,7 +877,7 @@ TEST_F(ResolverFunctionValidationTest, ParameterStoreType_NonAtomicFree) {
Member("m", ty.atomic(ty.i32())), Member("m", ty.atomic(ty.i32())),
}); });
auto* ret_type = ty.type_name(Source{{12, 34}}, "S"); auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
auto* bar = Param(Source{{12, 34}}, "bar", ret_type); auto* bar = Param("bar", ret_type);
Func("f", utils::Vector{bar}, ty.void_(), utils::Empty); Func("f", utils::Vector{bar}, ty.void_(), utils::Empty);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());

View File

@ -135,7 +135,8 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest, ResolverInferredTypeParamTest, testing::V
TEST_F(ResolverInferredTypeTest, InferArray_Pass) { TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
auto* type = ty.array(ty.u32(), 10_u); auto* type = ty.array(ty.u32(), 10_u);
auto* expected_type = create<sem::Array>(create<sem::U32>(), 10u, 4u, 4u * 10u, 4u, 4u); auto* expected_type =
create<sem::Array>(create<sem::U32>(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u);
auto* ctor_expr = Construct(type); auto* ctor_expr = Construct(type);
auto* var = Var("a", ast::StorageClass::kFunction, ctor_expr); auto* var = Var("a", ast::StorageClass::kFunction, ctor_expr);

View File

@ -513,7 +513,7 @@ bool match_array(const sem::Type* ty, const sem::Type*& T) {
} }
if (auto* a = ty->As<sem::Array>()) { if (auto* a = ty->As<sem::Array>()) {
if (a->Count() == 0) { if (a->IsRuntimeSized()) {
T = a->ElemType(); T = a->ElemType();
return true; return true;
} }
@ -523,7 +523,7 @@ bool match_array(const sem::Type* ty, const sem::Type*& T) {
const sem::Array* build_array(MatchState& state, const sem::Type* el) { const sem::Array* build_array(MatchState& state, const sem::Type* el) {
return state.builder.create<sem::Array>(el, return state.builder.create<sem::Array>(el,
/* count */ 0u, /* count */ sem::RuntimeArrayCount{},
/* align */ 0u, /* align */ 0u,
/* size */ 0u, /* size */ 0u,
/* stride */ 0u, /* stride */ 0u,

View File

@ -235,7 +235,7 @@ TEST_F(IntrinsicTableTest, MismatchPointer) {
} }
TEST_F(IntrinsicTableTest, MatchArray) { TEST_F(IntrinsicTableTest, MatchArray) {
auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u); auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr_ptr = create<sem::Pointer>(arr, ast::StorageClass::kStorage, ast::Access::kReadWrite); auto* arr_ptr = create<sem::Pointer>(arr, ast::StorageClass::kStorage, ast::Access::kReadWrite);
auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr}, Source{}); auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr}, Source{});
ASSERT_NE(result.sem, nullptr) << Diagnostics().str(); ASSERT_NE(result.sem, nullptr) << Diagnostics().str();
@ -798,7 +798,7 @@ TEST_F(IntrinsicTableTest, MatchTypeConversion) {
} }
TEST_F(IntrinsicTableTest, MismatchTypeConversion) { TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u); auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* f32 = create<sem::F32>(); auto* f32 = create<sem::F32>();
auto result = auto result =
table->Lookup(CtorConvIntrinsic::kVec3, f32, utils::Vector{arr}, Source{{12, 34}}); table->Lookup(CtorConvIntrinsic::kVec3, f32, utils::Vector{arr}, Source{{12, 34}});

View File

@ -106,12 +106,13 @@ TEST_F(ResolverIsHostShareable, Atomic) {
} }
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) { TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u); auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr)); EXPECT_TRUE(r()->IsHostShareable(arr));
} }
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) { TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u); auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr)); EXPECT_TRUE(r()->IsHostShareable(arr));
} }

View File

@ -89,12 +89,13 @@ TEST_F(ResolverIsStorableTest, Atomic) {
} }
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) { TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u); auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr)); EXPECT_TRUE(r()->IsStorable(arr));
} }
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) { TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u); auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr)); EXPECT_TRUE(r()->IsStorable(arr));
} }

View File

@ -118,7 +118,9 @@ class MaterializeTest : public resolver::ResolverTestWithParam<CASE> {
} }
}, },
[&](const sem::Array* a) { [&](const sem::Array* a) {
for (uint32_t i = 0; i < a->Count(); i++) { auto count = a->ConstantCount();
ASSERT_NE(count, 0u);
for (uint32_t i = 0; i < count; i++) {
auto* el = value->Index(i); auto* el = value->Index(i);
ASSERT_NE(el, nullptr); ASSERT_NE(el, nullptr);
EXPECT_TYPE(el->Type(), a->ElemType()); EXPECT_TYPE(el->Type(), a->ElemType());

View File

@ -1489,7 +1489,7 @@ const sem::Type* Resolver::ConcreteType(const sem::Type* ty,
target_el_ty = target_arr_ty->ElemType(); target_el_ty = target_arr_ty->ElemType();
} }
if (auto* el_ty = ConcreteType(a->ElemType(), target_el_ty, source)) { if (auto* el_ty = ConcreteType(a->ElemType(), target_el_ty, source)) {
return Array(source, el_ty, a->Count(), /* explicit_stride */ 0); return Array(source, source, el_ty, a->Count(), /* explicit_stride */ 0);
} }
return nullptr; return nullptr;
}); });
@ -1879,7 +1879,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
[&](const ast::Array* a) -> sem::Call* { [&](const ast::Array* a) -> sem::Call* {
Mark(a); Mark(a);
// array element type must be inferred if it was not specified. // array element type must be inferred if it was not specified.
auto el_count = static_cast<uint32_t>(args.Length()); sem::ArrayCount el_count =
sem::ConstantArrayCount{static_cast<uint32_t>(args.Length())};
const sem::Type* el_ty = nullptr; const sem::Type* el_ty = nullptr;
if (a->type) { if (a->type) {
el_ty = Type(a->type); el_ty = Type(a->type);
@ -1921,7 +1922,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return nullptr; return nullptr;
} }
auto* arr = Array(a->source, el_ty, el_count, explicit_stride); auto* arr = Array(a->type ? a->type->source : a->source,
a->count ? a->count->source : a->source, //
el_ty, el_count, explicit_stride);
if (!arr) { if (!arr) {
return nullptr; return nullptr;
} }
@ -2591,7 +2594,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
uint32_t el_count = 0; // sem::Array uses a size of 0 for a runtime-sized array. sem::ArrayCount el_count = sem::RuntimeArrayCount{};
// Evaluate the constant array size expression. // Evaluate the constant array size expression.
if (auto* count_expr = arr->count) { if (auto* count_expr = arr->count) {
@ -2602,7 +2605,9 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
} }
} }
auto* out = Array(arr->source, el_ty, el_count, explicit_stride); auto* out = Array(arr->type->source, //
arr->count ? arr->count->source : arr->source, //
el_ty, el_count, explicit_stride);
if (out == nullptr) { if (out == nullptr) {
return nullptr; return nullptr;
} }
@ -2619,16 +2624,27 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return out; return out;
} }
utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr) { utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array size expression. // Evaluate the constant array size expression.
const auto* count_sem = Materialize(Expression(count_expr)); const auto* count_sem = Materialize(Expression(count_expr));
if (!count_sem) { if (!count_sem) {
return utils::Failure; return utils::Failure;
} }
// Note: If the array count is an 'override', but not a identifier expression, we do not return
// here, but instead continue to the ConstantValue() check below.
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
if (global->Declaration()->Is<ast::Override>()) {
return sem::ArrayCount{sem::OverrideArrayCount{global}};
}
}
}
auto* count_val = count_sem->ConstantValue(); auto* count_val = count_sem->ConstantValue();
if (!count_val) { if (!count_val) {
AddError("array size must evaluate to a constant integer expression", count_expr->source); AddError("array size must evaluate to a constant integer expression or override variable",
count_expr->source);
return utils::Failure; return utils::Failure;
} }
@ -2646,7 +2662,7 @@ utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr)
return utils::Failure; return utils::Failure;
} }
return static_cast<uint32_t>(count); return sem::ArrayCount{sem::ConstantArrayCount{static_cast<uint32_t>(count)}};
} }
bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes, bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes,
@ -2673,27 +2689,33 @@ bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attribute
return true; return true;
} }
sem::Array* Resolver::Array(const Source& source, sem::Array* Resolver::Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty, const sem::Type* el_ty,
uint32_t el_count, sem::ArrayCount el_count,
uint32_t explicit_stride) { uint32_t explicit_stride) {
uint32_t el_align = el_ty->Align(); uint32_t el_align = el_ty->Align();
uint32_t el_size = el_ty->Size(); uint32_t el_size = el_ty->Size();
uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0; uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0;
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride; uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
uint64_t size = 0;
auto size = std::max<uint64_t>(el_count, 1u) * stride; if (auto const_count = std::get_if<sem::ConstantArrayCount>(&el_count)) {
if (size > std::numeric_limits<uint32_t>::max()) { size = const_count->value * stride;
std::stringstream msg; if (size > std::numeric_limits<uint32_t>::max()) {
msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes"; std::stringstream msg;
AddError(msg.str(), source); msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
return nullptr; AddError(msg.str(), count_source);
return nullptr;
}
} else if (std::holds_alternative<sem::RuntimeArrayCount>(el_count)) {
size = stride;
} }
auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size), auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride), static_cast<uint32_t>(stride),
static_cast<uint32_t>(implicit_stride)); static_cast<uint32_t>(implicit_stride));
if (!validator_.Array(out, source)) { if (!validator_.Array(out, el_source)) {
return nullptr; return nullptr;
} }

View File

@ -302,7 +302,7 @@ class Resolver {
/// Resolves and validates the expression used as the count parameter of an array. /// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>. /// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array. /// @returns the number of elements in the array.
utils::Result<uint32_t> ArrayCount(const ast::Expression* count_expr); utils::Result<sem::ArrayCount> ArrayCount(const ast::Expression* count_expr);
/// Resolves and validates the attributes on an array. /// Resolves and validates the attributes on an array.
/// @param attributes the attributes on the array type. /// @param attributes the attributes on the array type.
@ -315,13 +315,17 @@ class Resolver {
/// Builds and returns the semantic information for an array. /// Builds and returns the semantic information for an array.
/// @returns the semantic Array information, or nullptr if an error is raised. /// @returns the semantic Array information, or nullptr if an error is raised.
/// @param source the source of the array declaration /// @param el_source the source of the array element, or the array if the array does not have a
/// locally-declared element AST node.
/// @param count_source the source of the array count, or the array if the array does not have a
/// locally-declared element AST node.
/// @param el_ty the Array element type /// @param el_ty the Array element type
/// @param el_count the number of elements in the array. Zero means runtime-sized. /// @param el_count the number of elements in the array.
/// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride. /// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride.
sem::Array* Array(const Source& source, sem::Array* Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty, const sem::Type* el_ty,
uint32_t el_count, sem::ArrayCount el_count,
uint32_t explicit_stride); uint32_t explicit_stride);
/// Builds and returns the semantic information for the alias `alias`. /// Builds and returns the semantic information for the alias `alias`.

View File

@ -432,7 +432,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>(); auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr); ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>(); auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
} }
TEST_F(ResolverTest, ArraySize_SignedLiteral) { TEST_F(ResolverTest, ArraySize_SignedLiteral) {
@ -445,7 +445,7 @@ TEST_F(ResolverTest, ArraySize_SignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>(); auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr); ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>(); auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
} }
TEST_F(ResolverTest, ArraySize_UnsignedConst) { TEST_F(ResolverTest, ArraySize_UnsignedConst) {
@ -460,7 +460,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>(); auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr); ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>(); auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
} }
TEST_F(ResolverTest, ArraySize_SignedConst) { TEST_F(ResolverTest, ArraySize_SignedConst) {
@ -475,7 +475,51 @@ TEST_F(ResolverTest, ArraySize_SignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>(); auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr); ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>(); auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u); EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_Override) {
// override size = 0;
// var<workgroup> a : array<f32, size>;
auto* override = Override("size", Expr(10_i));
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(a), nullptr);
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), sem::OverrideArrayCount{sem_override});
}
TEST_F(ResolverTest, ArraySize_Override_Equivalence) {
// override size = 0;
// var<workgroup> a : array<f32, size>;
// var<workgroup> b : array<f32, size>;
auto* override = Override("size", Expr(10_i));
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
auto* b = GlobalVar("b", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(a), nullptr);
auto* ref_a = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref_a, nullptr);
auto* ary_a = ref_a->StoreType()->As<sem::Array>();
ASSERT_NE(TypeOf(b), nullptr);
auto* ref_b = TypeOf(b)->As<sem::Reference>();
ASSERT_NE(ref_b, nullptr);
auto* ary_b = ref_b->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), sem::OverrideArrayCount{sem_override});
EXPECT_EQ(ary_b->Count(), sem::OverrideArrayCount{sem_override});
EXPECT_EQ(ary_a, ary_b);
} }
TEST_F(ResolverTest, Expr_Bitcast) { TEST_F(ResolverTest, Expr_Bitcast) {

View File

@ -654,9 +654,13 @@ struct DataType<array<N, T>> {
/// @return the semantic array type /// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) { static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* el = DataType<T>::Sem(b); auto* el = DataType<T>::Sem(b);
sem::ArrayCount count = sem::ConstantArrayCount{N};
if (N == 0) {
count = sem::RuntimeArrayCount{};
}
return b.create<sem::Array>( return b.create<sem::Array>(
/* element */ el, /* element */ el,
/* count */ N, /* count */ count,
/* align */ el->Align(), /* align */ el->Align(),
/* size */ N * el->Size(), /* size */ N * el->Size(),
/* stride */ el->Align(), /* stride */ el->Align(),

View File

@ -310,7 +310,8 @@ TEST_F(ResolverTypeValidationTest, ArraySize_IVecConst) {
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) { TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
// var<private> a : array<f32, 0x40000000u>; // var<private> a : array<f32, 0x40000000u>;
GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000_u), ast::StorageClass::kPrivate); GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x40000000_u)),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes"); "12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
@ -318,21 +319,157 @@ TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) { TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
// var<private> a : @stride(8) array<f32, 0x20000000u>; // var<private> a : @stride(8) array<f32, 0x20000000u>;
GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000_u, 8), GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x20000000_u), 8),
ast::StorageClass::kPrivate); ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes"); "12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
} }
TEST_F(ResolverTypeValidationTest, ArraySize_Overridable) { TEST_F(ResolverTypeValidationTest, ArraySize_Override_PrivateVar) {
// override size = 10i; // override size = 10i;
// var<private> a : array<f32, size>; // var<private> a : array<f32, size>;
Override("size", Expr(10_i)); Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate); GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), "size"), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression"); "12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ComplexExpr) {
// override size = 10i;
// var<workgroup> a : array<f32, size + 1>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Add(Source{{12, 34}}, "size", 1_i)),
ast::StorageClass::kWorkgroup);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression or override "
"variable");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InArray) {
// override size = 10i;
// var<workgroup> a : array<array<f32, size>, 4>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.array(Source{{12, 34}}, ty.f32(), "size"), 4_a),
ast::StorageClass::kWorkgroup);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InStruct) {
// override size = 10i;
// struct S {
// a : array<f32, size>
// };
Override("size", Expr(10_i));
Structure("S", utils::Vector{Member("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Explicit) {
// override size = 10i;
// fn f() {
// var a : array<f32, size>;
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Explicit) {
// override size = 10i;
// fn f() {
// var a : array<f32, size>;
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Implicit) {
// override size = 10i;
// var<workgroup> w : array<f32, size>;
// fn f() {
// var a = w;
// }
Override("size", Expr(10_i));
GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", Expr(Source{{12, 34}}, "w"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Implicit) {
// override size = 10i;
// var<workgroup> w : array<f32, size>;
// fn f() {
// let a = w;
// }
Override("size", Expr(10_i));
GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("a", Expr(Source{{12, 34}}, "w"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_Param) {
// override size = 10i;
// fn f(a : array<f32, size>) {
// }
Override("size", Expr(10_i));
Func("f", utils::Vector{Param("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))}, ty.void_(),
utils::Empty);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: type of function parameter must be constructible");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ReturnType) {
// override size = 10i;
// fn f() -> array<f32, size> {
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.array(Source{{12, 34}}, ty.f32(), "size"), utils::Empty);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: function return type must be a constructible type");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Workgroup_Overridable) {
// override size = 10i;
// var<workgroup> a : array<f32, size>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) { TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) {
@ -367,7 +504,8 @@ TEST_F(ResolverTypeValidationTest, ArraySize_FunctionLet) {
WrapInFunction(size, a); WrapInFunction(size, a);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression"); "12:34 error: array size must evaluate to a constant integer expression or override "
"variable");
} }
TEST_F(ResolverTypeValidationTest, ArraySize_ComplexExpr) { TEST_F(ResolverTypeValidationTest, ArraySize_ComplexExpr) {
@ -477,7 +615,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInArray) {
// }; // };
Structure("Foo", utils::Vector{ Structure("Foo", utils::Vector{
Member("rt", ty.array(Source{{12, 34}}, ty.array<f32>(), 4_u)), Member("rt", ty.array(ty.array(Source{{12, 34}}, ty.f32()), 4_u)),
}); });
EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_FALSE(r()->Resolve()) << r()->error();
@ -491,10 +629,9 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInStructInArray) {
// }; // };
// var<private> a : array<Foo, 4>; // var<private> a : array<Foo, 4>;
auto* foo = Structure("Foo", utils::Vector{ Structure("Foo", utils::Vector{Member("rt", ty.array<f32>())});
Member("rt", ty.array<f32>()), GlobalVar("v", ty.array(ty.type_name(Source{{12, 34}}, "Foo"), 4_u),
}); ast::StorageClass::kPrivate);
GlobalVar("v", ty.array(Source{{12, 34}}, ty.Of(foo), 4_u), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()) << r()->error(); EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
@ -636,8 +773,8 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
} }
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) { TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
auto* tex_ty = ty.sampled_texture(ast::TextureDimension::k2d, ty.f32()); auto* tex_ty = ty.sampled_texture(Source{{12, 34}}, ast::TextureDimension::k2d, ty.f32());
GlobalVar("arr", ty.array(Source{{12, 34}}, tex_ty, 4_i), ast::StorageClass::kPrivate); GlobalVar("arr", ty.array(tex_ty, 4_i), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),

View File

@ -548,23 +548,28 @@ bool Validator::StorageClassLayout(const sem::Variable* var,
return true; return true;
} }
bool Validator::LocalVariable(const sem::Variable* v) const { bool Validator::LocalVariable(const sem::Variable* local) const {
auto* decl = v->Declaration(); auto* decl = local->Declaration();
if (IsArrayWithOverrideCount(local->Type())) {
RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
: decl->constructor->source);
return false;
}
return Switch( return Switch(
decl, // decl, //
[&](const ast::Var* var) { [&](const ast::Var* var) {
if (IsValidationEnabled(var->attributes, if (IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) { ast::DisabledValidation::kIgnoreStorageClass)) {
if (!v->Type()->UnwrapRef()->IsConstructible()) { if (!local->Type()->UnwrapRef()->IsConstructible()) {
AddError("function-scope 'var' must have a constructible type", AddError("function-scope 'var' must have a constructible type",
var->type ? var->type->source : var->source); var->type ? var->type->source : var->source);
return false; return false;
} }
} }
return Var(v); return Var(local);
}, // }, //
[&](const ast::Let*) { return Let(v); }, // [&](const ast::Let*) { return Let(local); }, //
[&](const ast::Const*) { return true; }, // [&](const ast::Const*) { return true; }, //
[&](Default) { [&](Default) {
TINT_ICE(Resolver, diagnostics_) TINT_ICE(Resolver, diagnostics_)
<< "Validator::Variable() called with a unknown variable type: " << "Validator::Variable() called with a unknown variable type: "
@ -578,6 +583,12 @@ bool Validator::GlobalVariable(
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids, const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const { const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
auto* decl = global->Declaration(); auto* decl = global->Declaration();
if (global->StorageClass() != ast::StorageClass::kWorkgroup &&
IsArrayWithOverrideCount(global->Type())) {
RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
: decl->constructor->source);
return false;
}
bool ok = Switch( bool ok = Switch(
decl, // decl, //
[&](const ast::Var* var) { [&](const ast::Var* var) {
@ -875,7 +886,7 @@ bool Validator::Parameter(const ast::Function* func, const sem::Variable* var) c
if (IsPlain(var->Type())) { if (IsPlain(var->Type())) {
if (!var->Type()->IsConstructible()) { if (!var->Type()->IsConstructible()) {
AddError("type of function parameter must be constructible", decl->source); AddError("type of function parameter must be constructible", decl->type->source);
return false; return false;
} }
} else if (!var->Type()->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) { } else if (!var->Type()->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
@ -1898,20 +1909,23 @@ bool Validator::ArrayConstructor(const ast::CallExpression* ctor,
if (array_type->IsRuntimeSized()) { if (array_type->IsRuntimeSized()) {
AddError("cannot construct a runtime-sized array", ctor->source); AddError("cannot construct a runtime-sized array", ctor->source);
return false; return false;
} else if (!elem_ty->IsConstructible()) { }
if (array_type->IsOverrideSized()) {
AddError("cannot construct an array that has an override expression count", ctor->source);
return false;
}
if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type", ctor->source); AddError("array constructor has non-constructible element type", ctor->source);
return false; return false;
} else if (!values.IsEmpty() && (values.Length() != array_type->Count())) { }
std::string fm = values.Length() < array_type->Count() ? "few" : "many";
const auto count = std::get<sem::ConstantArrayCount>(array_type->Count()).value;
if (!values.IsEmpty() && (values.Length() != count)) {
std::string fm = values.Length() < count ? "few" : "many";
AddError("array constructor has too " + fm + " elements: expected " + AddError("array constructor has too " + fm + " elements: expected " +
std::to_string(array_type->Count()) + ", found " + std::to_string(count) + ", found " + std::to_string(values.Length()),
std::to_string(values.Length()),
ctor->source);
return false;
} else if (values.Length() > array_type->Count()) {
AddError("array constructor has too many elements: expected " +
std::to_string(array_type->Count()) + ", found " +
std::to_string(values.Length()),
ctor->source); ctor->source);
return false; return false;
} }
@ -2086,18 +2100,25 @@ bool Validator::PushConstants(const std::vector<sem::Function*>& entry_points) c
return true; return true;
} }
bool Validator::Array(const sem::Array* arr, const Source& source) const { bool Validator::Array(const sem::Array* arr, const Source& el_source) const {
auto* el_ty = arr->ElemType(); auto* el_ty = arr->ElemType();
if (!IsPlain(el_ty)) { if (!IsPlain(el_ty)) {
AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array", source); AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array",
el_source);
return false; return false;
} }
if (!IsFixedFootprint(el_ty)) { if (!IsFixedFootprint(el_ty)) {
AddError("an array element type cannot contain a runtime-sized array", source); AddError("an array element type cannot contain a runtime-sized array", el_source);
return false; return false;
} }
if (IsArrayWithOverrideCount(el_ty)) {
RaiseArrayWithOverrideCountError(el_source);
return false;
}
return true; return true;
} }
@ -2154,6 +2175,11 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
return false; return false;
} }
} }
if (IsArrayWithOverrideCount(member->Type())) {
RaiseArrayWithOverrideCountError(member->Declaration()->type->source);
return false;
}
} else if (!IsFixedFootprint(member->Type())) { } else if (!IsFixedFootprint(member->Type())) {
AddError( AddError(
"a struct that contains a runtime array cannot be nested inside " "a struct that contains a runtime array cannot be nested inside "
@ -2488,6 +2514,22 @@ bool Validator::IsValidationEnabled(utils::VectorRef<const ast::Attribute*> attr
return !IsValidationDisabled(attributes, validation); return !IsValidationDisabled(attributes, validation);
} }
bool Validator::IsArrayWithOverrideCount(const sem::Type* ty) const {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (arr->IsOverrideSized()) {
return true;
}
}
return false;
}
void Validator::RaiseArrayWithOverrideCountError(const Source& source) const {
AddError(
"array with an 'override' element count can only be used as the store type of a "
"'var<workgroup>'",
source);
}
std::string Validator::VectorPretty(uint32_t size, const sem::Type* element_type) const { std::string Validator::VectorPretty(uint32_t size, const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size); sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(symbols_); return vec_type.FriendlyName(symbols_);

View File

@ -128,9 +128,10 @@ class Validator {
/// Validates the array /// Validates the array
/// @param arr the array to validate /// @param arr the array to validate
/// @param source the source of the array /// @param el_source the source of the array element, or the array if the array does not have a
/// locally-declared element AST node.
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool Array(const sem::Array* arr, const Source& source) const; bool Array(const sem::Array* arr, const Source& el_source) const;
/// Validates an array stride attribute /// Validates an array stride attribute
/// @param attr the stride attribute to validate /// @param attr the stride attribute to validate
@ -463,6 +464,16 @@ class Validator {
ast::DisabledValidation validation) const; ast::DisabledValidation validation) const;
private: private:
/// @param ty the type to check
/// @returns true if @p ty is an array with an `override` expression element count, otherwise
/// false.
bool IsArrayWithOverrideCount(const sem::Type* ty) const;
/// Raises an error about an array type using an `override` expression element count, outside
/// the single allowed use of a `var<workgroup>`.
/// @param source the source for the error
void RaiseArrayWithOverrideCountError(const Source& source) const;
/// Searches the current statement and up through parents of the current /// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement. /// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that /// @returns the closest continuing statement to the current statement that

View File

@ -89,12 +89,13 @@ TEST_F(ValidatorIsStorableTest, Atomic) {
} }
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) { TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u); auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr)); EXPECT_TRUE(v()->IsStorable(arr));
} }
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) { TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u); auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr)); EXPECT_TRUE(v()->IsStorable(arr));
} }

View File

@ -16,15 +16,22 @@
#include <string> #include <string>
#include "src/tint/ast/variable.h"
#include "src/tint/debug.h" #include "src/tint/debug.h"
#include "src/tint/sem/variable.h"
#include "src/tint/symbol_table.h"
#include "src/tint/utils/hash.h" #include "src/tint/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::Array); TINT_INSTANTIATE_TYPEINFO(tint::sem::Array);
namespace tint::sem { namespace tint::sem {
const char* Array::kErrExpectedConstantCount =
"array size is an override-expression, when expected a constant-expression.\n"
"Was the SubstituteOverride transform run?";
Array::Array(const Type* element, Array::Array(const Type* element,
uint32_t count, ArrayCount count,
uint32_t align, uint32_t align,
uint32_t size, uint32_t size,
uint32_t stride, uint32_t stride,
@ -35,8 +42,9 @@ Array::Array(const Type* element,
size_(size), size_(size),
stride_(stride), stride_(stride),
implicit_stride_(implicit_stride), implicit_stride_(implicit_stride),
constructible_(count > 0 // Runtime-sized arrays are not constructible // Only constant-expression sized arrays are constructible
&& element->IsConstructible()) { constructible_(std::holds_alternative<ConstantArrayCount>(count) &&
element->IsConstructible()) {
TINT_ASSERT(Semantic, element_); TINT_ASSERT(Semantic, element_);
} }
@ -64,8 +72,10 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
out << "@stride(" << stride_ << ") "; out << "@stride(" << stride_ << ") ";
} }
out << "array<" << element_->FriendlyName(symbols); out << "array<" << element_->FriendlyName(symbols);
if (!IsRuntimeSized()) { if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
out << ", " << count_; out << ", " << const_count->value;
} else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
} }
out << ">"; out << ">";
return out.str(); return out.str();

View File

@ -16,29 +16,116 @@
#define SRC_TINT_SEM_ARRAY_H_ #define SRC_TINT_SEM_ARRAY_H_
#include <stdint.h> #include <stdint.h>
#include <optional>
#include <string> #include <string>
#include <variant>
#include "src/tint/sem/node.h" #include "src/tint/sem/node.h"
#include "src/tint/sem/type.h" #include "src/tint/sem/type.h"
#include "src/tint/utils/compiler_macros.h"
// Forward declarations
namespace tint::sem {
class GlobalVariable;
} // namespace tint::sem
namespace tint::sem { namespace tint::sem {
/// The variant of an ArrayCount when the array is a constant expression.
/// Example:
/// ```
/// const N = 123;
/// type arr = array<i32, N>
/// ```
struct ConstantArrayCount {
/// The array count constant-expression value.
uint32_t value;
};
/// The variant of an ArrayCount when the count is a named override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N>
/// ```
struct OverrideArrayCount {
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
/// type arr = array<i32>
/// ```
struct RuntimeArrayCount {};
/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
/// Equality operator
/// @param a the LHS ConstantArrayCount
/// @param b the RHS ConstantArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
return a.value == b.value;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
return a.variable == b.variable;
}
/// Equality operator
/// @returns true
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
return true;
}
/// Equality operator
/// @param a the LHS ArrayCount
/// @param b the RHS count
/// @returns true if @p a is equal to @p b
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
std::is_same_v<T, OverrideArrayCount> ||
std::is_same_v<T, RuntimeArrayCount>>>
inline bool operator==(const ArrayCount& a, const T& b) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
return std::visit(
[&](auto count) {
if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
return count == b;
}
return false;
},
a);
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
/// Array holds the semantic information for Array nodes. /// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> { class Array final : public Castable<Array, Type> {
public: public:
/// An error message string stating that the array count was expected to be a constant
/// expression. Used by multiple writers and transforms.
static const char* kErrExpectedConstantCount;
/// Constructor /// Constructor
/// @param element the array element type /// @param element the array element type
/// @param count the number of elements in the array. 0 represents a /// @param count the number of elements in the array.
/// runtime-sized array.
/// @param align the byte alignment of the array /// @param align the byte alignment of the array
/// @param size the byte size of the array /// @param size the byte size of the array. The size will be 0 if the array element count is
/// pipeline overridable.
/// @param stride the number of bytes from the start of one element of the /// @param stride the number of bytes from the start of one element of the
/// array to the start of the next element /// array to the start of the next element
/// @param implicit_stride the number of bytes from the start of one element /// @param implicit_stride the number of bytes from the start of one element
/// of the array to the start of the next element, if there was no `@stride` /// of the array to the start of the next element, if there was no `@stride`
/// attribute applied. /// attribute applied.
Array(Type const* element, Array(Type const* element,
uint32_t count, ArrayCount count,
uint32_t align, uint32_t align,
uint32_t size, uint32_t size,
uint32_t stride, uint32_t stride,
@ -54,9 +141,16 @@ class Array final : public Castable<Array, Type> {
/// @return the array element type /// @return the array element type
Type const* ElemType() const { return element_; } Type const* ElemType() const { return element_; }
/// @returns the number of elements in the array. 0 represents a runtime-sized /// @returns the number of elements in the array.
/// array. const ArrayCount& Count() const { return count_; }
uint32_t Count() const { return count_; }
/// @returns the array count if the count is a constant expression, otherwise returns nullopt.
inline std::optional<uint32_t> ConstantCount() const {
if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
return count->value;
}
return std::nullopt;
}
/// @returns the byte alignment of the array /// @returns the byte alignment of the array
/// @note this may differ from the alignment of a structure member of this /// @note this may differ from the alignment of a structure member of this
@ -81,8 +175,14 @@ class Array final : public Castable<Array, Type> {
/// natural stride /// natural stride
bool IsStrideImplicit() const { return stride_ == implicit_stride_; } bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
/// @returns true if this array is sized using an constant expression
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
/// @returns true if this array is sized using an override variable
bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
/// @returns true if this array is runtime sized /// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return count_ == 0; } bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
/// @returns true if constructible as per /// @returns true if constructible as per
/// https://gpuweb.github.io/gpuweb/wgsl/#constructible-types /// https://gpuweb.github.io/gpuweb/wgsl/#constructible-types
@ -95,7 +195,7 @@ class Array final : public Castable<Array, Type> {
private: private:
Type const* const element_; Type const* const element_;
const uint32_t count_; const ArrayCount count_;
const uint32_t align_; const uint32_t align_;
const uint32_t size_; const uint32_t size_;
const uint32_t stride_; const uint32_t stride_;
@ -105,4 +205,38 @@ class Array final : public Castable<Array, Type> {
} // namespace tint::sem } // namespace tint::sem
namespace std {
/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
template <>
class hash<tint::sem::ConstantArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
return std::hash<decltype(count.value)>()(count.value);
}
};
/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
template <>
class hash<tint::sem::OverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
return std::hash<decltype(count.variable)>()(count.variable);
}
};
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
template <>
class hash<tint::sem::RuntimeArrayCount> {
public:
/// @return the hash value
inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
};
} // namespace std
#endif // SRC_TINT_SEM_ARRAY_H_ #endif // SRC_TINT_SEM_ARRAY_H_

View File

@ -21,16 +21,16 @@ namespace {
using ArrayTest = TestHelper; using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) { TEST_F(ArrayTest, CreateSizedArray) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u); auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u); auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u); auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u); auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u); auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>()); EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 2u); EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
EXPECT_EQ(a->Align(), 4u); EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u); EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u); EXPECT_EQ(a->Stride(), 32u);
@ -47,15 +47,15 @@ TEST_F(ArrayTest, CreateSizedArray) {
} }
TEST_F(ArrayTest, CreateRuntimeArray) { TEST_F(ArrayTest, CreateRuntimeArray) {
auto* a = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u); auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u); auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), 0u, 5u, 8u, 32u, 32u); auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), 0u, 4u, 9u, 32u, 32u); auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 32u); auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 17u); auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>()); EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 0u); EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
EXPECT_EQ(a->Align(), 4u); EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u); EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u); EXPECT_EQ(a->Stride(), 32u);
@ -71,13 +71,13 @@ TEST_F(ArrayTest, CreateRuntimeArray) {
} }
TEST_F(ArrayTest, Hash) { TEST_F(ArrayTest, Hash) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u); auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u); auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u); auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u); auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u); auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash()); EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash()); EXPECT_NE(a->Hash(), c->Hash());
@ -88,13 +88,13 @@ TEST_F(ArrayTest, Hash) {
} }
TEST_F(ArrayTest, Equals) { TEST_F(ArrayTest, Equals) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u); auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u); auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u); auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u); auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u); auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u); auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b)); EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c)); EXPECT_FALSE(a->Equals(*c));
@ -106,22 +106,22 @@ TEST_F(ArrayTest, Equals) {
} }
TEST_F(ArrayTest, FriendlyNameRuntimeSized) { TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 4u, 4u); auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>"); EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
} }
TEST_F(ArrayTest, FriendlyNameStaticSized) { TEST_F(ArrayTest, FriendlyNameStaticSized) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 4u, 4u); auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>"); EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
} }
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) { TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 8u, 4u); auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>"); EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
} }
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) { TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 8u, 4u); auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>"); EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
} }

View File

@ -246,7 +246,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
}, },
[&](const Array* a) { [&](const Array* a) {
if (count) { if (count) {
*count = a->Count(); if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
*count = const_count->value;
}
} }
return a->ElemType(); return a->ElemType();
}, },

View File

@ -62,56 +62,56 @@ struct TypeTest : public TestHelper {
/* size_no_padding*/ 4u); /* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>( const sem::Array* arr_i32 = create<Array>(
/* element */ i32, /* element */ i32,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 4u, /* align */ 4u,
/* size */ 5u * 4u, /* size */ 5u * 4u,
/* stride */ 5u * 4u, /* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u); /* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>( const sem::Array* arr_ai = create<Array>(
/* element */ ai, /* element */ ai,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 4u, /* align */ 4u,
/* size */ 5u * 4u, /* size */ 5u * 4u,
/* stride */ 5u * 4u, /* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u); /* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>( const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32, /* element */ vec3_i32,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 16u, /* align */ 16u,
/* size */ 5u * 16u, /* size */ 5u * 16u,
/* stride */ 5u * 16u, /* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u); /* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>( const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai, /* element */ vec3_ai,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 16u, /* align */ 16u,
/* size */ 5u * 16u, /* size */ 5u * 16u,
/* stride */ 5u * 16u, /* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u); /* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>( const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16, /* element */ mat4x3_f16,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 32u, /* align */ 32u,
/* size */ 5u * 32u, /* size */ 5u * 32u,
/* stride */ 5u * 32u, /* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u); /* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>( const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32, /* element */ mat4x3_f32,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 64u, /* align */ 64u,
/* size */ 5u * 64u, /* size */ 5u * 64u,
/* stride */ 5u * 64u, /* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u); /* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>( const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af, /* element */ mat4x3_af,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 64u, /* align */ 64u,
/* size */ 5u * 64u, /* size */ 5u * 64u,
/* stride */ 5u * 64u, /* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u); /* implicit_stride */ 5u * 64u);
const sem::Array* arr_str = create<Array>( const sem::Array* arr_str = create<Array>(
/* element */ str, /* element */ str,
/* count */ 5u, /* count */ ConstantArrayCount{5u},
/* align */ 4u, /* align */ 4u,
/* size */ 5u * 4u, /* size */ 5u * 4u,
/* stride */ 5u * 4u, /* stride */ 5u * 4u,

View File

@ -471,8 +471,18 @@ struct DecomposeMemoryAccess::State {
auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty)); auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u)); auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i); auto* for_init = b.Decl(i);
auto arr_cnt = arr_ty->ConstantCount();
if (!arr_cnt) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles storage and uniform.
// * Runtime-sized arrays are not loadable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>( auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count()))); ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u)); auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i); auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride()))); auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -562,8 +572,18 @@ struct DecomposeMemoryAccess::State {
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user); StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u)); auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i); auto* for_init = b.Decl(i);
auto arr_cnt = arr_ty->ConstantCount();
if (!arr_cnt) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup
// arrays, and this method only handles storage and uniform.
// * Runtime-sized arrays are not storable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>( auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count()))); ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u)); auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i); auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset = auto* el_offset =

View File

@ -82,7 +82,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
// std140 structs should be padded out to 16 bytes. // std140 structs should be padded out to 16 bytes.
size = utils::RoundUp(16u, size); size = utils::RoundUp(16u, size);
} else if (auto* array_ty = ty->As<sem::Array>()) { } else if (auto* array_ty = ty->As<sem::Array>()) {
if (array_ty->Count() == 0) { if (array_ty->IsRuntimeSized()) {
has_runtime_sized_array = true; has_runtime_sized_array = true;
} }
} }

View File

@ -99,14 +99,21 @@ struct Robustness::State {
// Must clamp, even if the index is constant. // Must clamp, even if the index is constant.
auto* arr_ptr = b.AddressOf(ctx.Clone(expr->object)); auto* arr_ptr = b.AddressOf(ctx.Clone(expr->object));
max = b.Sub(b.Call("arrayLength", arr_ptr), 1_u); max = b.Sub(b.Call("arrayLength", arr_ptr), 1_u);
} else { } else if (auto count = arr->ConstantCount()) {
if (sem->Index()->ConstantValue()) { if (sem->Index()->ConstantValue()) {
// Index and size is constant. // Index and size is constant.
// Validation will have rejected any OOB accesses. // Validation will have rejected any OOB accesses.
return nullptr; return nullptr;
} }
max = b.Expr(u32(arr->Count() - 1u)); max = b.Expr(u32(count.value() - 1u));
} else {
// Note: Don't be tempted to use the array override variable as an expression
// here, the name might be shadowed!
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return nullptr;
} }
return b.Call("min", idx(), max); return b.Call("min", idx(), max);
}, },
[&](Default) { [&](Default) {

View File

@ -1316,5 +1316,23 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(RobustnessTest, WorkgroupOverrideCount) {
auto* src = R"(
override N = 123;
var<workgroup> w : array<f32, N>;
fn f() {
var b : f32 = w[1i];
}
)";
auto* expect = R"(error: array size is an override-expression, when expected a constant-expression.
Was the SubstituteOverride transform run?)";
auto got = Run<Robustness>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace } // namespace
} // namespace tint::transform } // namespace tint::transform

View File

@ -35,6 +35,8 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SpirvAtomic::Stub);
namespace tint::transform { namespace tint::transform {
using namespace tint::number_suffixes; // NOLINT
/// Private implementation of transform /// Private implementation of transform
struct SpirvAtomic::State { struct SpirvAtomic::State {
private: private:
@ -189,10 +191,19 @@ struct SpirvAtomic::State {
[&](const sem::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); }, [&](const sem::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); }, [&](const sem::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::Struct* str) { return b.ty.type_name(Fork(str->Declaration()).name); }, [&](const sem::Struct* str) { return b.ty.type_name(Fork(str->Declaration()).name); },
[&](const sem::Array* arr) { [&](const sem::Array* arr) -> const ast::Type* {
return arr->IsRuntimeSized() if (arr->IsRuntimeSized()) {
? b.ty.array(AtomicTypeFor(arr->ElemType())) return b.ty.array(AtomicTypeFor(arr->ElemType()));
: b.ty.array(AtomicTypeFor(arr->ElemType()), u32(arr->Count())); }
auto count = arr->ConstantCount();
if (!count) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"the SpirvAtomic transform does not currently support array counts that "
"use override values");
count = 1;
}
return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
}, },
[&](const sem::Pointer* ptr) { [&](const sem::Pointer* ptr) {
return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->StorageClass(), return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->StorageClass(),

View File

@ -423,7 +423,17 @@ struct Std140::State {
if (!arr->IsStrideImplicit()) { if (!arr->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride())); attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
} }
return b.create<ast::Array>(std140, b.Expr(u32(arr->Count())), auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
return b.create<ast::Array>(std140, b.Expr(u32(count.value())),
std::move(attrs)); std::move(attrs));
} }
return nullptr; return nullptr;
@ -613,7 +623,17 @@ struct Std140::State {
ty, // ty, //
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); }, [&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
[&](const sem::Array* arr) { [&](const sem::Array* arr) {
return "arr" + std::to_string(arr->Count()) + "_" + ConvertSuffix(arr->ElemType()); auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
}, },
[&](const sem::Matrix* mat) { [&](const sem::Matrix* mat) {
return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) + return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
@ -710,10 +730,20 @@ struct Std140::State {
auto* i = b.Var("i", b.ty.u32()); auto* i = b.Var("i", b.ty.u32());
auto* dst_el = b.IndexAccessor(var, i); auto* dst_el = b.IndexAccessor(var, i);
auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i)); auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i));
auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
stmts.Push(b.Decl(var)); stmts.Push(b.Decl(var));
stmts.Push(b.For(b.Decl(i), // stmts.Push(b.For(b.Decl(i), //
b.LessThan(i, u32(arr->Count())), // b.LessThan(i, u32(count.value())), //
b.Assign(i, b.Add(i, 1_a)), // b.Assign(i, b.Add(i, 1_a)), //
b.Block(b.Assign(dst_el, src_el)))); b.Block(b.Assign(dst_el, src_el))));
stmts.Push(b.Return(var)); stmts.Push(b.Return(var));
}, },

View File

@ -24,6 +24,7 @@
#include "src/tint/sem/for_loop_statement.h" #include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/reference.h" #include "src/tint/sem/reference.h"
#include "src/tint/sem/sampler.h" #include "src/tint/sem/sampler.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform); TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data); TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
@ -112,9 +113,16 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type*
} }
if (a->IsRuntimeSized()) { if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs)); return ctx.dst->ty.array(el, nullptr, std::move(attrs));
} else {
return ctx.dst->ty.array(el, u32(a->Count()), std::move(attrs));
} }
if (auto* override = std::get_if<sem::OverrideArrayCount>(&a->Count())) {
auto* count = ctx.Clone(override->variable->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs));
}
if (auto count = a->ConstantCount()) {
return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
}
TINT_ICE(Transform, ctx.dst->Diagnostics()) << sem::Array::kErrExpectedConstantCount;
return ctx.dst->ty.array(el, u32(1), std::move(attrs));
} }
if (auto* s = ty->As<sem::Struct>()) { if (auto* s = ty->As<sem::Struct>()) {
return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name)); return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));

View File

@ -65,7 +65,8 @@ TEST_F(CreateASTTypeForTest, Vector) {
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) { TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) { auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 32u, 32u); return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u,
32u);
}); });
ASSERT_TRUE(arr->Is<ast::Array>()); ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>()); ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
@ -78,7 +79,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) { TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) { auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 64u, 32u); return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u,
32u);
}); });
ASSERT_TRUE(arr->Is<ast::Array>()); ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>()); ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());

View File

@ -307,7 +307,13 @@ struct ZeroInitWorkgroupMemory::State {
// `num_values * arr->Count()` // `num_values * arr->Count()`
// The index for this array is: // The index for this array is:
// `(idx % modulo) / division` // `(idx % modulo) / division`
auto modulo = num_values * arr->Count(); auto count = arr->ConstantCount();
if (!count) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return Expression{};
}
auto modulo = num_values * count.value();
auto division = num_values; auto division = num_values;
auto a = get_expr(modulo); auto a = get_expr(modulo);
auto array_indices = a.array_indices; auto array_indices = a.array_indices;

View File

@ -2240,7 +2240,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
ScopedParen sp(out); ScopedParen sp(out);
for (size_t i = 0; i < a->Count(); i++) { auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
@ -2356,16 +2362,23 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
} }
EmitZeroValue(out, member->Type()); EmitZeroValue(out, member->Type());
} }
} else if (auto* array = type->As<sem::Array>()) { } else if (auto* arr = type->As<sem::Array>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, "")) { if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
return false; return false;
} }
ScopedParen sp(out); ScopedParen sp(out);
for (uint32_t i = 0; i < array->Count(); i++) {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (uint32_t i = 0; i < count; i++) {
if (i != 0) { if (i != 0) {
out << ", "; out << ", ";
} }
EmitZeroValue(out, array->ElemType()); EmitZeroValue(out, arr->ElemType());
} }
} else { } else {
diagnostics_.add_error(diag::System::Writer, "Invalid type for zero emission: " + diagnostics_.add_error(diag::System::Writer, "Invalid type for zero emission: " +
@ -2697,7 +2710,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
const sem::Type* base_type = ary; const sem::Type* base_type = ary;
std::vector<uint32_t> sizes; std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) { while (auto* arr = base_type->As<sem::Array>()) {
sizes.push_back(arr->Count()); if (arr->IsRuntimeSized()) {
sizes.push_back(0);
} else {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
sizes.push_back(count.value());
}
base_type = arr->ElemType(); base_type = arr->ElemType();
} }
if (!EmitType(out, base_type, storage_class, access, "")) { if (!EmitType(out, base_type, storage_class, access, "")) {

View File

@ -3193,7 +3193,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
out << "{"; out << "{";
TINT_DEFER(out << "}"); TINT_DEFER(out << "}");
for (size_t i = 0; i < a->Count(); i++) { auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
@ -3732,11 +3738,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
while (auto* arr = base_type->As<sem::Array>()) { while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) { if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should have " << "runtime arrays may only exist in storage buffers, which should have "
"been transformed into a ByteAddressBuffer"; "been transformed into a ByteAddressBuffer";
return false; return false;
} }
sizes.push_back(arr->Count()); const auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
sizes.push_back(count.value());
base_type = arr->ElemType(); base_type = arr->ElemType();
} }
if (!EmitType(out, base_type, storage_class, access, "")) { if (!EmitType(out, base_type, storage_class, access, "")) {

View File

@ -1690,7 +1690,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
return true; return true;
} }
for (size_t i = 0; i < a->Count(); i++) { auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) { if (i > 0) {
out << ", "; out << ", ";
} }
@ -2481,7 +2487,20 @@ bool GeneratorImpl::EmitType(std::ostream& out,
if (!EmitType(out, arr->ElemType(), "")) { if (!EmitType(out, arr->ElemType(), "")) {
return false; return false;
} }
out << ", " << (arr->IsRuntimeSized() ? 1u : arr->Count()) << ">"; out << ", ";
if (arr->IsRuntimeSized()) {
out << "1";
} else {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
out << count.value();
}
out << ">";
return true; return true;
}, },
[&](const sem::Bool*) { [&](const sem::Bool*) {
@ -3133,8 +3152,14 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
<< "arrays with explicit strides should not exist past the SPIR-V reader"; << "arrays with explicit strides should not exist past the SPIR-V reader";
return SizeAndAlign{}; return SizeAndAlign{};
} }
auto num_els = std::max<uint32_t>(arr->Count(), 1); if (arr->IsRuntimeSized()) {
return SizeAndAlign{arr->Stride() * num_els, arr->Align()}; return SizeAndAlign{arr->Stride(), arr->Align()};
}
if (auto count = arr->ConstantCount()) {
return SizeAndAlign{arr->Stride() * count.value(), arr->Align()};
}
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return SizeAndAlign{};
}, },
[&](const sem::Struct* str) { [&](const sem::Struct* str) {

View File

@ -1676,11 +1676,18 @@ uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
}, },
[&](const sem::Vector* v) { return composite(v->Width()); }, [&](const sem::Vector* v) { return composite(v->Width()); },
[&](const sem::Matrix* m) { return composite(m->columns()); }, [&](const sem::Matrix* m) { return composite(m->columns()); },
[&](const sem::Array* a) { return composite(a->Count()); }, [&](const sem::Array* a) {
auto count = a->ConstantCount();
if (!count) {
error_ = sem::Array::kErrExpectedConstantCount;
return static_cast<uint32_t>(0);
}
return composite(count.value());
},
[&](const sem::Struct* s) { return composite(s->Members().size()); }, [&](const sem::Struct* s) { return composite(s->Members().size()); },
[&](Default) { [&](Default) {
error_ = "unhandled constant type: " + builder_.FriendlyName(ty); error_ = "unhandled constant type: " + builder_.FriendlyName(ty);
return false; return 0;
}); });
} }
@ -3852,17 +3859,23 @@ bool Builder::GenerateTextureType(const sem::Texture* texture, const Operand& re
return true; return true;
} }
bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) { bool Builder::GenerateArrayType(const sem::Array* arr, const Operand& result) {
auto elem_type = GenerateTypeIfNeeded(ary->ElemType()); auto elem_type = GenerateTypeIfNeeded(arr->ElemType());
if (elem_type == 0) { if (elem_type == 0) {
return false; return false;
} }
auto result_id = std::get<uint32_t>(result); auto result_id = std::get<uint32_t>(result);
if (ary->IsRuntimeSized()) { if (arr->IsRuntimeSized()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand(elem_type)}); push_type(spv::Op::OpTypeRuntimeArray, {result, Operand(elem_type)});
} else { } else {
auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->Count())); auto count = arr->ConstantCount();
if (!count) {
error_ = sem::Array::kErrExpectedConstantCount;
return static_cast<uint32_t>(0);
}
auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(count.value()));
if (len_id == 0) { if (len_id == 0) {
return false; return false;
} }
@ -3871,7 +3884,7 @@ bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
} }
push_annot(spv::Op::OpDecorate, push_annot(spv::Op::OpDecorate,
{Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(ary->Stride())}); {Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(arr->Stride())});
return true; return true;
} }

View File

@ -0,0 +1,5 @@
// flags: --transform substitute_override
override size = 2;
var<workgroup> a : array<f32, size>;

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
groupshared float a[2];

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
groupshared float a[2];

View File

@ -0,0 +1,7 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
shared float a[2];

View File

@ -0,0 +1,3 @@
#include <metal_stdlib>
using namespace metal;

View File

@ -0,0 +1,24 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 11
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %a "a"
OpName %unused_entry_point "unused_entry_point"
OpDecorate %_arr_float_uint_2 ArrayStride 4
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_float_uint_2 = OpTypeArray %float %uint_2
%_ptr_Workgroup__arr_float_uint_2 = OpTypePointer Workgroup %_arr_float_uint_2
%a = OpVariable %_ptr_Workgroup__arr_float_uint_2 Workgroup
%void = OpTypeVoid
%7 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %7
%10 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,3 @@
const size = 2;
var<workgroup> a : array<f32, size>;