Fix overrides in array size.

This CL fixes the usage of overrides in array sizes. Currently
the usage will generate a validation error as we check that the
array size is const.

Bug: tint:1660
Change-Id: Ibf440905c30a73b581d55b0c071b8621b61605e6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101900
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
dan sinclair 2022-09-22 22:28:21 +00:00 committed by Dawn LUCI CQ
parent 534a198f88
commit 78f8067fd5
42 changed files with 839 additions and 186 deletions

View File

@ -394,8 +394,10 @@ const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return builder.create<Splat>(type, zero_el, m->columns());
},
[&](const sem::Array* a) -> const ImplConstant* {
if (auto n = a->ConstantCount()) {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count());
return builder.create<Splat>(type, zero_el, n.value());
}
}
return nullptr;
},
@ -451,12 +453,16 @@ bool Equal(const sem::Constant* a, const sem::Constant* b) {
return true;
},
[&](const sem::Array* arr) {
for (size_t i = 0; i < arr->Count(); i++) {
if (auto count = arr->ConstantCount()) {
for (size_t i = 0; i < count; i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
}
return false;
},
[&](Default) { return a->Value() == b->Value(); });
}

View File

@ -1700,7 +1700,7 @@ TEST_F(ResolverConstEvalTest, Array_i32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), 4u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1738,7 +1738,7 @@ TEST_F(ResolverConstEvalTest, Array_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), 4u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1776,7 +1776,7 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), 2u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1828,7 +1828,7 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Zero) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), 2u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
@ -1866,7 +1866,7 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::I32>());
EXPECT_EQ(arr->Count(), 4u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1904,7 +1904,7 @@ TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::F32>());
EXPECT_EQ(arr->Count(), 4u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{4u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1943,7 +1943,7 @@ TEST_F(ResolverConstEvalTest, Array_vec3_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Vector>());
EXPECT_EQ(arr->Count(), 2u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
@ -1973,7 +1973,7 @@ TEST_F(ResolverConstEvalTest, Array_Struct_f32_Elements) {
auto* arr = sem->Type()->As<sem::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TRUE(arr->ElemType()->Is<sem::Struct>());
EXPECT_EQ(arr->Count(), 2u);
EXPECT_EQ(arr->Count(), sem::ConstantArrayCount{2u});
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());

View File

@ -877,7 +877,7 @@ TEST_F(ResolverFunctionValidationTest, ParameterStoreType_NonAtomicFree) {
Member("m", ty.atomic(ty.i32())),
});
auto* ret_type = ty.type_name(Source{{12, 34}}, "S");
auto* bar = Param(Source{{12, 34}}, "bar", ret_type);
auto* bar = Param("bar", ret_type);
Func("f", utils::Vector{bar}, ty.void_(), utils::Empty);
EXPECT_FALSE(r()->Resolve());

View File

@ -135,7 +135,8 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest, ResolverInferredTypeParamTest, testing::V
TEST_F(ResolverInferredTypeTest, InferArray_Pass) {
auto* type = ty.array(ty.u32(), 10_u);
auto* expected_type = create<sem::Array>(create<sem::U32>(), 10u, 4u, 4u * 10u, 4u, 4u);
auto* expected_type =
create<sem::Array>(create<sem::U32>(), sem::ConstantArrayCount{10u}, 4u, 4u * 10u, 4u, 4u);
auto* ctor_expr = Construct(type);
auto* var = Var("a", ast::StorageClass::kFunction, ctor_expr);

View File

@ -513,7 +513,7 @@ bool match_array(const sem::Type* ty, const sem::Type*& T) {
}
if (auto* a = ty->As<sem::Array>()) {
if (a->Count() == 0) {
if (a->IsRuntimeSized()) {
T = a->ElemType();
return true;
}
@ -523,7 +523,7 @@ bool match_array(const sem::Type* ty, const sem::Type*& T) {
const sem::Array* build_array(MatchState& state, const sem::Type* el) {
return state.builder.create<sem::Array>(el,
/* count */ 0u,
/* count */ sem::RuntimeArrayCount{},
/* align */ 0u,
/* size */ 0u,
/* stride */ 0u,

View File

@ -235,7 +235,7 @@ TEST_F(IntrinsicTableTest, MismatchPointer) {
}
TEST_F(IntrinsicTableTest, MatchArray) {
auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* arr_ptr = create<sem::Pointer>(arr, ast::StorageClass::kStorage, ast::Access::kReadWrite);
auto result = table->Lookup(BuiltinType::kArrayLength, utils::Vector{arr_ptr}, Source{});
ASSERT_NE(result.sem, nullptr) << Diagnostics().str();
@ -798,7 +798,7 @@ TEST_F(IntrinsicTableTest, MatchTypeConversion) {
}
TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
auto* arr = create<sem::Array>(create<sem::U32>(), 0u, 4u, 4u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::U32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
auto* f32 = create<sem::F32>();
auto result =
table->Lookup(CtorConvIntrinsic::kVec3, f32, utils::Vector{arr}, Source{{12, 34}});

View File

@ -106,12 +106,13 @@ TEST_F(ResolverIsHostShareable, Atomic) {
}
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsHostShareable(arr));
}

View File

@ -89,12 +89,13 @@ TEST_F(ResolverIsStorableTest, Atomic) {
}
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(r()->IsStorable(arr));
}

View File

@ -118,7 +118,9 @@ class MaterializeTest : public resolver::ResolverTestWithParam<CASE> {
}
},
[&](const sem::Array* a) {
for (uint32_t i = 0; i < a->Count(); i++) {
auto count = a->ConstantCount();
ASSERT_NE(count, 0u);
for (uint32_t i = 0; i < count; i++) {
auto* el = value->Index(i);
ASSERT_NE(el, nullptr);
EXPECT_TYPE(el->Type(), a->ElemType());

View File

@ -1489,7 +1489,7 @@ const sem::Type* Resolver::ConcreteType(const sem::Type* ty,
target_el_ty = target_arr_ty->ElemType();
}
if (auto* el_ty = ConcreteType(a->ElemType(), target_el_ty, source)) {
return Array(source, el_ty, a->Count(), /* explicit_stride */ 0);
return Array(source, source, el_ty, a->Count(), /* explicit_stride */ 0);
}
return nullptr;
});
@ -1879,7 +1879,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
[&](const ast::Array* a) -> sem::Call* {
Mark(a);
// array element type must be inferred if it was not specified.
auto el_count = static_cast<uint32_t>(args.Length());
sem::ArrayCount el_count =
sem::ConstantArrayCount{static_cast<uint32_t>(args.Length())};
const sem::Type* el_ty = nullptr;
if (a->type) {
el_ty = Type(a->type);
@ -1921,7 +1922,9 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return nullptr;
}
auto* arr = Array(a->source, el_ty, el_count, explicit_stride);
auto* arr = Array(a->type ? a->type->source : a->source,
a->count ? a->count->source : a->source, //
el_ty, el_count, explicit_stride);
if (!arr) {
return nullptr;
}
@ -2591,7 +2594,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr;
}
uint32_t el_count = 0; // sem::Array uses a size of 0 for a runtime-sized array.
sem::ArrayCount el_count = sem::RuntimeArrayCount{};
// Evaluate the constant array size expression.
if (auto* count_expr = arr->count) {
@ -2602,7 +2605,9 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
}
}
auto* out = Array(arr->source, el_ty, el_count, explicit_stride);
auto* out = Array(arr->type->source, //
arr->count ? arr->count->source : arr->source, //
el_ty, el_count, explicit_stride);
if (out == nullptr) {
return nullptr;
}
@ -2619,16 +2624,27 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return out;
}
utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr) {
utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count_expr) {
// Evaluate the constant array size expression.
const auto* count_sem = Materialize(Expression(count_expr));
if (!count_sem) {
return utils::Failure;
}
// Note: If the array count is an 'override', but not a identifier expression, we do not return
// here, but instead continue to the ConstantValue() check below.
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
if (global->Declaration()->Is<ast::Override>()) {
return sem::ArrayCount{sem::OverrideArrayCount{global}};
}
}
}
auto* count_val = count_sem->ConstantValue();
if (!count_val) {
AddError("array size must evaluate to a constant integer expression", count_expr->source);
AddError("array size must evaluate to a constant integer expression or override variable",
count_expr->source);
return utils::Failure;
}
@ -2646,7 +2662,7 @@ utils::Result<uint32_t> Resolver::ArrayCount(const ast::Expression* count_expr)
return utils::Failure;
}
return static_cast<uint32_t>(count);
return sem::ArrayCount{sem::ConstantArrayCount{static_cast<uint32_t>(count)}};
}
bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attributes,
@ -2673,27 +2689,33 @@ bool Resolver::ArrayAttributes(utils::VectorRef<const ast::Attribute*> attribute
return true;
}
sem::Array* Resolver::Array(const Source& source,
sem::Array* Resolver::Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
uint32_t el_count,
sem::ArrayCount el_count,
uint32_t explicit_stride) {
uint32_t el_align = el_ty->Align();
uint32_t el_size = el_ty->Size();
uint64_t implicit_stride = el_size ? utils::RoundUp<uint64_t>(el_align, el_size) : 0;
uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
uint64_t size = 0;
auto size = std::max<uint64_t>(el_count, 1u) * stride;
if (auto const_count = std::get_if<sem::ConstantArrayCount>(&el_count)) {
size = const_count->value * stride;
if (size > std::numeric_limits<uint32_t>::max()) {
std::stringstream msg;
msg << "array size (0x" << std::hex << size << ") must not exceed 0xffffffff bytes";
AddError(msg.str(), source);
AddError(msg.str(), count_source);
return nullptr;
}
} else if (std::holds_alternative<sem::RuntimeArrayCount>(el_count)) {
size = stride;
}
auto* out = builder_->create<sem::Array>(el_ty, el_count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride),
static_cast<uint32_t>(implicit_stride));
if (!validator_.Array(out, source)) {
if (!validator_.Array(out, el_source)) {
return nullptr;
}

View File

@ -302,7 +302,7 @@ class Resolver {
/// Resolves and validates the expression used as the count parameter of an array.
/// @param count_expr the expression used as the second template parameter to an array<>.
/// @returns the number of elements in the array.
utils::Result<uint32_t> ArrayCount(const ast::Expression* count_expr);
utils::Result<sem::ArrayCount> ArrayCount(const ast::Expression* count_expr);
/// Resolves and validates the attributes on an array.
/// @param attributes the attributes on the array type.
@ -315,13 +315,17 @@ class Resolver {
/// Builds and returns the semantic information for an array.
/// @returns the semantic Array information, or nullptr if an error is raised.
/// @param source the source of the array declaration
/// @param el_source the source of the array element, or the array if the array does not have a
/// locally-declared element AST node.
/// @param count_source the source of the array count, or the array if the array does not have a
/// locally-declared element AST node.
/// @param el_ty the Array element type
/// @param el_count the number of elements in the array. Zero means runtime-sized.
/// @param el_count the number of elements in the array.
/// @param explicit_stride the explicit byte stride of the array. Zero means implicit stride.
sem::Array* Array(const Source& source,
sem::Array* Array(const Source& el_source,
const Source& count_source,
const sem::Type* el_ty,
uint32_t el_count,
sem::ArrayCount el_count,
uint32_t explicit_stride);
/// Builds and returns the semantic information for the alias `alias`.

View File

@ -432,7 +432,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u);
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_SignedLiteral) {
@ -445,7 +445,7 @@ TEST_F(ResolverTest, ArraySize_SignedLiteral) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u);
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_UnsignedConst) {
@ -460,7 +460,7 @@ TEST_F(ResolverTest, ArraySize_UnsignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u);
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_SignedConst) {
@ -475,7 +475,51 @@ TEST_F(ResolverTest, ArraySize_SignedConst) {
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
EXPECT_EQ(ary->Count(), 10u);
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
}
TEST_F(ResolverTest, ArraySize_Override) {
// override size = 0;
// var<workgroup> a : array<f32, size>;
auto* override = Override("size", Expr(10_i));
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(a), nullptr);
auto* ref = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref, nullptr);
auto* ary = ref->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary->Count(), sem::OverrideArrayCount{sem_override});
}
TEST_F(ResolverTest, ArraySize_Override_Equivalence) {
// override size = 0;
// var<workgroup> a : array<f32, size>;
// var<workgroup> b : array<f32, size>;
auto* override = Override("size", Expr(10_i));
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
auto* b = GlobalVar("b", ty.array(ty.f32(), Expr("size")), ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(a), nullptr);
auto* ref_a = TypeOf(a)->As<sem::Reference>();
ASSERT_NE(ref_a, nullptr);
auto* ary_a = ref_a->StoreType()->As<sem::Array>();
ASSERT_NE(TypeOf(b), nullptr);
auto* ref_b = TypeOf(b)->As<sem::Reference>();
ASSERT_NE(ref_b, nullptr);
auto* ary_b = ref_b->StoreType()->As<sem::Array>();
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
ASSERT_NE(sem_override, nullptr);
EXPECT_EQ(ary_a->Count(), sem::OverrideArrayCount{sem_override});
EXPECT_EQ(ary_b->Count(), sem::OverrideArrayCount{sem_override});
EXPECT_EQ(ary_a, ary_b);
}
TEST_F(ResolverTest, Expr_Bitcast) {

View File

@ -654,9 +654,13 @@ struct DataType<array<N, T>> {
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* el = DataType<T>::Sem(b);
sem::ArrayCount count = sem::ConstantArrayCount{N};
if (N == 0) {
count = sem::RuntimeArrayCount{};
}
return b.create<sem::Array>(
/* element */ el,
/* count */ N,
/* count */ count,
/* align */ el->Align(),
/* size */ N * el->Size(),
/* stride */ el->Align(),

View File

@ -310,7 +310,8 @@ TEST_F(ResolverTypeValidationTest, ArraySize_IVecConst) {
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
// var<private> a : array<f32, 0x40000000u>;
GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000_u), ast::StorageClass::kPrivate);
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x40000000_u)),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
@ -318,21 +319,157 @@ TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
// var<private> a : @stride(8) array<f32, 0x20000000u>;
GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000_u, 8),
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, 0x20000000_u), 8),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size (0x100000000) must not exceed 0xffffffff bytes");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Overridable) {
TEST_F(ResolverTypeValidationTest, ArraySize_Override_PrivateVar) {
// override size = 10i;
// var<private> a : array<f32, size>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")), ast::StorageClass::kPrivate);
GlobalVar("a", ty.array(Source{{12, 34}}, ty.f32(), "size"), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression");
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ComplexExpr) {
// override size = 10i;
// var<workgroup> a : array<f32, size + 1>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Add(Source{{12, 34}}, "size", 1_i)),
ast::StorageClass::kWorkgroup);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression or override "
"variable");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InArray) {
// override size = 10i;
// var<workgroup> a : array<array<f32, size>, 4>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.array(Source{{12, 34}}, ty.f32(), "size"), 4_a),
ast::StorageClass::kWorkgroup);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InStruct) {
// override size = 10i;
// struct S {
// a : array<f32, size>
// };
Override("size", Expr(10_i));
Structure("S", utils::Vector{Member("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Explicit) {
// override size = 10i;
// fn f() {
// var a : array<f32, size>;
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Explicit) {
// override size = 10i;
// fn f() {
// var a : array<f32, size>;
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Implicit) {
// override size = 10i;
// var<workgroup> w : array<f32, size>;
// fn f() {
// var a = w;
// }
Override("size", Expr(10_i));
GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("a", Expr(Source{{12, 34}}, "w"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Implicit) {
// override size = 10i;
// var<workgroup> w : array<f32, size>;
// fn f() {
// let a = w;
// }
Override("size", Expr(10_i));
GlobalVar("w", ty.array(ty.f32(), "size"), ast::StorageClass::kWorkgroup);
Func("f", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("a", Expr(Source{{12, 34}}, "w"))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array with an 'override' element count can only be used as the store "
"type of a 'var<workgroup>'");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_Param) {
// override size = 10i;
// fn f(a : array<f32, size>) {
// }
Override("size", Expr(10_i));
Func("f", utils::Vector{Param("a", ty.array(Source{{12, 34}}, ty.f32(), "size"))}, ty.void_(),
utils::Empty);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: type of function parameter must be constructible");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ReturnType) {
// override size = 10i;
// fn f() -> array<f32, size> {
// }
Override("size", Expr(10_i));
Func("f", utils::Empty, ty.array(Source{{12, 34}}, ty.f32(), "size"), utils::Empty);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: function return type must be a constructible type");
}
TEST_F(ResolverTypeValidationTest, ArraySize_Workgroup_Overridable) {
// override size = 10i;
// var<workgroup> a : array<f32, size>;
Override("size", Expr(10_i));
GlobalVar("a", ty.array(ty.f32(), Expr(Source{{12, 34}}, "size")),
ast::StorageClass::kWorkgroup);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeValidationTest, ArraySize_ModuleVar) {
@ -367,7 +504,8 @@ TEST_F(ResolverTypeValidationTest, ArraySize_FunctionLet) {
WrapInFunction(size, a);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size must evaluate to a constant integer expression");
"12:34 error: array size must evaluate to a constant integer expression or override "
"variable");
}
TEST_F(ResolverTypeValidationTest, ArraySize_ComplexExpr) {
@ -477,7 +615,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInArray) {
// };
Structure("Foo", utils::Vector{
Member("rt", ty.array(Source{{12, 34}}, ty.array<f32>(), 4_u)),
Member("rt", ty.array(ty.array(Source{{12, 34}}, ty.f32()), 4_u)),
});
EXPECT_FALSE(r()->Resolve()) << r()->error();
@ -491,10 +629,9 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInStructInArray) {
// };
// var<private> a : array<Foo, 4>;
auto* foo = Structure("Foo", utils::Vector{
Member("rt", ty.array<f32>()),
});
GlobalVar("v", ty.array(Source{{12, 34}}, ty.Of(foo), 4_u), ast::StorageClass::kPrivate);
Structure("Foo", utils::Vector{Member("rt", ty.array<f32>())});
GlobalVar("v", ty.array(ty.type_name(Source{{12, 34}}, "Foo"), 4_u),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
@ -636,8 +773,8 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
}
TEST_F(ResolverTypeValidationTest, ArrayOfNonStorableType) {
auto* tex_ty = ty.sampled_texture(ast::TextureDimension::k2d, ty.f32());
GlobalVar("arr", ty.array(Source{{12, 34}}, tex_ty, 4_i), ast::StorageClass::kPrivate);
auto* tex_ty = ty.sampled_texture(Source{{12, 34}}, ast::TextureDimension::k2d, ty.f32());
GlobalVar("arr", ty.array(tex_ty, 4_i), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),

View File

@ -548,22 +548,27 @@ bool Validator::StorageClassLayout(const sem::Variable* var,
return true;
}
bool Validator::LocalVariable(const sem::Variable* v) const {
auto* decl = v->Declaration();
bool Validator::LocalVariable(const sem::Variable* local) const {
auto* decl = local->Declaration();
if (IsArrayWithOverrideCount(local->Type())) {
RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
: decl->constructor->source);
return false;
}
return Switch(
decl, //
[&](const ast::Var* var) {
if (IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) {
if (!v->Type()->UnwrapRef()->IsConstructible()) {
if (!local->Type()->UnwrapRef()->IsConstructible()) {
AddError("function-scope 'var' must have a constructible type",
var->type ? var->type->source : var->source);
return false;
}
}
return Var(v);
return Var(local);
}, //
[&](const ast::Let*) { return Let(v); }, //
[&](const ast::Let*) { return Let(local); }, //
[&](const ast::Const*) { return true; }, //
[&](Default) {
TINT_ICE(Resolver, diagnostics_)
@ -578,6 +583,12 @@ bool Validator::GlobalVariable(
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
auto* decl = global->Declaration();
if (global->StorageClass() != ast::StorageClass::kWorkgroup &&
IsArrayWithOverrideCount(global->Type())) {
RaiseArrayWithOverrideCountError(decl->type ? decl->type->source
: decl->constructor->source);
return false;
}
bool ok = Switch(
decl, //
[&](const ast::Var* var) {
@ -875,7 +886,7 @@ bool Validator::Parameter(const ast::Function* func, const sem::Variable* var) c
if (IsPlain(var->Type())) {
if (!var->Type()->IsConstructible()) {
AddError("type of function parameter must be constructible", decl->source);
AddError("type of function parameter must be constructible", decl->type->source);
return false;
}
} else if (!var->Type()->IsAnyOf<sem::Texture, sem::Sampler, sem::Pointer>()) {
@ -1898,20 +1909,23 @@ bool Validator::ArrayConstructor(const ast::CallExpression* ctor,
if (array_type->IsRuntimeSized()) {
AddError("cannot construct a runtime-sized array", ctor->source);
return false;
} else if (!elem_ty->IsConstructible()) {
}
if (array_type->IsOverrideSized()) {
AddError("cannot construct an array that has an override expression count", ctor->source);
return false;
}
if (!elem_ty->IsConstructible()) {
AddError("array constructor has non-constructible element type", ctor->source);
return false;
} else if (!values.IsEmpty() && (values.Length() != array_type->Count())) {
std::string fm = values.Length() < array_type->Count() ? "few" : "many";
}
const auto count = std::get<sem::ConstantArrayCount>(array_type->Count()).value;
if (!values.IsEmpty() && (values.Length() != count)) {
std::string fm = values.Length() < count ? "few" : "many";
AddError("array constructor has too " + fm + " elements: expected " +
std::to_string(array_type->Count()) + ", found " +
std::to_string(values.Length()),
ctor->source);
return false;
} else if (values.Length() > array_type->Count()) {
AddError("array constructor has too many elements: expected " +
std::to_string(array_type->Count()) + ", found " +
std::to_string(values.Length()),
std::to_string(count) + ", found " + std::to_string(values.Length()),
ctor->source);
return false;
}
@ -2086,18 +2100,25 @@ bool Validator::PushConstants(const std::vector<sem::Function*>& entry_points) c
return true;
}
bool Validator::Array(const sem::Array* arr, const Source& source) const {
bool Validator::Array(const sem::Array* arr, const Source& el_source) const {
auto* el_ty = arr->ElemType();
if (!IsPlain(el_ty)) {
AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array", source);
AddError(sem_.TypeNameOf(el_ty) + " cannot be used as an element type of an array",
el_source);
return false;
}
if (!IsFixedFootprint(el_ty)) {
AddError("an array element type cannot contain a runtime-sized array", source);
AddError("an array element type cannot contain a runtime-sized array", el_source);
return false;
}
if (IsArrayWithOverrideCount(el_ty)) {
RaiseArrayWithOverrideCountError(el_source);
return false;
}
return true;
}
@ -2154,6 +2175,11 @@ bool Validator::Structure(const sem::Struct* str, ast::PipelineStage stage) cons
return false;
}
}
if (IsArrayWithOverrideCount(member->Type())) {
RaiseArrayWithOverrideCountError(member->Declaration()->type->source);
return false;
}
} else if (!IsFixedFootprint(member->Type())) {
AddError(
"a struct that contains a runtime array cannot be nested inside "
@ -2488,6 +2514,22 @@ bool Validator::IsValidationEnabled(utils::VectorRef<const ast::Attribute*> attr
return !IsValidationDisabled(attributes, validation);
}
bool Validator::IsArrayWithOverrideCount(const sem::Type* ty) const {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (arr->IsOverrideSized()) {
return true;
}
}
return false;
}
void Validator::RaiseArrayWithOverrideCountError(const Source& source) const {
AddError(
"array with an 'override' element count can only be used as the store type of a "
"'var<workgroup>'",
source);
}
std::string Validator::VectorPretty(uint32_t size, const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(symbols_);

View File

@ -128,9 +128,10 @@ class Validator {
/// Validates the array
/// @param arr the array to validate
/// @param source the source of the array
/// @param el_source the source of the array element, or the array if the array does not have a
/// locally-declared element AST node.
/// @returns true on success, false otherwise.
bool Array(const sem::Array* arr, const Source& source) const;
bool Array(const sem::Array* arr, const Source& el_source) const;
/// Validates an array stride attribute
/// @param attr the stride attribute to validate
@ -463,6 +464,16 @@ class Validator {
ast::DisabledValidation validation) const;
private:
/// @param ty the type to check
/// @returns true if @p ty is an array with an `override` expression element count, otherwise
/// false.
bool IsArrayWithOverrideCount(const sem::Type* ty) const;
/// Raises an error about an array type using an `override` expression element count, outside
/// the single allowed use of a `var<workgroup>`.
/// @param source the source for the error
void RaiseArrayWithOverrideCountError(const Source& source) const;
/// Searches the current statement and up through parents of the current
/// statement looking for a loop or for-loop continuing statement.
/// @returns the closest continuing statement to the current statement that

View File

@ -89,12 +89,13 @@ TEST_F(ValidatorIsStorableTest, Atomic) {
}
TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 5u, 4u, 20u, 4u, 4u);
auto* arr =
create<sem::Array>(create<sem::I32>(), sem::ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}
TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) {
auto* arr = create<sem::Array>(create<sem::I32>(), 0u, 4u, 4u, 4u, 4u);
auto* arr = create<sem::Array>(create<sem::I32>(), sem::RuntimeArrayCount{}, 4u, 4u, 4u, 4u);
EXPECT_TRUE(v()->IsStorable(arr));
}

View File

@ -16,15 +16,22 @@
#include <string>
#include "src/tint/ast/variable.h"
#include "src/tint/debug.h"
#include "src/tint/sem/variable.h"
#include "src/tint/symbol_table.h"
#include "src/tint/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::Array);
namespace tint::sem {
const char* Array::kErrExpectedConstantCount =
"array size is an override-expression, when expected a constant-expression.\n"
"Was the SubstituteOverride transform run?";
Array::Array(const Type* element,
uint32_t count,
ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@ -35,8 +42,9 @@ Array::Array(const Type* element,
size_(size),
stride_(stride),
implicit_stride_(implicit_stride),
constructible_(count > 0 // Runtime-sized arrays are not constructible
&& element->IsConstructible()) {
// Only constant-expression sized arrays are constructible
constructible_(std::holds_alternative<ConstantArrayCount>(count) &&
element->IsConstructible()) {
TINT_ASSERT(Semantic, element_);
}
@ -64,8 +72,10 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
out << "@stride(" << stride_ << ") ";
}
out << "array<" << element_->FriendlyName(symbols);
if (!IsRuntimeSized()) {
out << ", " << count_;
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
out << ", " << const_count->value;
} else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
}
out << ">";
return out.str();

View File

@ -16,29 +16,116 @@
#define SRC_TINT_SEM_ARRAY_H_
#include <stdint.h>
#include <optional>
#include <string>
#include <variant>
#include "src/tint/sem/node.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/compiler_macros.h"
// Forward declarations
namespace tint::sem {
class GlobalVariable;
} // namespace tint::sem
namespace tint::sem {
/// The variant of an ArrayCount when the array is a constant expression.
/// Example:
/// ```
/// const N = 123;
/// type arr = array<i32, N>
/// ```
struct ConstantArrayCount {
/// The array count constant-expression value.
uint32_t value;
};
/// The variant of an ArrayCount when the count is a named override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N>
/// ```
struct OverrideArrayCount {
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
/// type arr = array<i32>
/// ```
struct RuntimeArrayCount {};
/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
/// Equality operator
/// @param a the LHS ConstantArrayCount
/// @param b the RHS ConstantArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
return a.value == b.value;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
return a.variable == b.variable;
}
/// Equality operator
/// @returns true
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
return true;
}
/// Equality operator
/// @param a the LHS ArrayCount
/// @param b the RHS count
/// @returns true if @p a is equal to @p b
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
std::is_same_v<T, OverrideArrayCount> ||
std::is_same_v<T, RuntimeArrayCount>>>
inline bool operator==(const ArrayCount& a, const T& b) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
return std::visit(
[&](auto count) {
if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
return count == b;
}
return false;
},
a);
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
/// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> {
public:
/// An error message string stating that the array count was expected to be a constant
/// expression. Used by multiple writers and transforms.
static const char* kErrExpectedConstantCount;
/// Constructor
/// @param element the array element type
/// @param count the number of elements in the array. 0 represents a
/// runtime-sized array.
/// @param count the number of elements in the array.
/// @param align the byte alignment of the array
/// @param size the byte size of the array
/// @param size the byte size of the array. The size will be 0 if the array element count is
/// pipeline overridable.
/// @param stride the number of bytes from the start of one element of the
/// array to the start of the next element
/// @param implicit_stride the number of bytes from the start of one element
/// of the array to the start of the next element, if there was no `@stride`
/// attribute applied.
Array(Type const* element,
uint32_t count,
ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@ -54,9 +141,16 @@ class Array final : public Castable<Array, Type> {
/// @return the array element type
Type const* ElemType() const { return element_; }
/// @returns the number of elements in the array. 0 represents a runtime-sized
/// array.
uint32_t Count() const { return count_; }
/// @returns the number of elements in the array.
const ArrayCount& Count() const { return count_; }
/// @returns the array count if the count is a constant expression, otherwise returns nullopt.
inline std::optional<uint32_t> ConstantCount() const {
if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
return count->value;
}
return std::nullopt;
}
/// @returns the byte alignment of the array
/// @note this may differ from the alignment of a structure member of this
@ -81,8 +175,14 @@ class Array final : public Castable<Array, Type> {
/// natural stride
bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
/// @returns true if this array is sized using an constant expression
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
/// @returns true if this array is sized using an override variable
bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
/// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return count_ == 0; }
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
/// @returns true if constructible as per
/// https://gpuweb.github.io/gpuweb/wgsl/#constructible-types
@ -95,7 +195,7 @@ class Array final : public Castable<Array, Type> {
private:
Type const* const element_;
const uint32_t count_;
const ArrayCount count_;
const uint32_t align_;
const uint32_t size_;
const uint32_t stride_;
@ -105,4 +205,38 @@ class Array final : public Castable<Array, Type> {
} // namespace tint::sem
namespace std {
/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
template <>
class hash<tint::sem::ConstantArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
return std::hash<decltype(count.value)>()(count.value);
}
};
/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
template <>
class hash<tint::sem::OverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
return std::hash<decltype(count.variable)>()(count.variable);
}
};
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
template <>
class hash<tint::sem::RuntimeArrayCount> {
public:
/// @return the hash value
inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
};
} // namespace std
#endif // SRC_TINT_SEM_ARRAY_H_

View File

@ -21,16 +21,16 @@ namespace {
using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 2u);
EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@ -47,15 +47,15 @@ TEST_F(ArrayTest, CreateSizedArray) {
}
TEST_F(ArrayTest, CreateRuntimeArray) {
auto* a = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), 0u, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), 0u, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 0u);
EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@ -71,13 +71,13 @@ TEST_F(ArrayTest, CreateRuntimeArray) {
}
TEST_F(ArrayTest, Hash) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash());
@ -88,13 +88,13 @@ TEST_F(ArrayTest, Hash) {
}
TEST_F(ArrayTest, Equals) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@ -106,22 +106,22 @@ TEST_F(ArrayTest, Equals) {
}
TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSized) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
}
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
}

View File

@ -246,7 +246,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
},
[&](const Array* a) {
if (count) {
*count = a->Count();
if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
*count = const_count->value;
}
}
return a->ElemType();
},

View File

@ -62,56 +62,56 @@ struct TypeTest : public TestHelper {
/* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>(
/* element */ ai,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 32u,
/* size */ 5u * 32u,
/* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_str = create<Array>(
/* element */ str,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,

View File

@ -471,8 +471,18 @@ struct DecomposeMemoryAccess::State {
auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i);
auto arr_cnt = arr_ty->ConstantCount();
if (!arr_cnt) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles storage and uniform.
// * Runtime-sized arrays are not loadable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -562,8 +572,18 @@ struct DecomposeMemoryAccess::State {
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), b.Expr(0_u));
auto* for_init = b.Decl(i);
auto arr_cnt = arr_ty->ConstantCount();
if (!arr_cnt) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup
// arrays, and this method only handles storage and uniform.
// * Runtime-sized arrays are not storable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset =

View File

@ -82,7 +82,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
// std140 structs should be padded out to 16 bytes.
size = utils::RoundUp(16u, size);
} else if (auto* array_ty = ty->As<sem::Array>()) {
if (array_ty->Count() == 0) {
if (array_ty->IsRuntimeSized()) {
has_runtime_sized_array = true;
}
}

View File

@ -99,14 +99,21 @@ struct Robustness::State {
// Must clamp, even if the index is constant.
auto* arr_ptr = b.AddressOf(ctx.Clone(expr->object));
max = b.Sub(b.Call("arrayLength", arr_ptr), 1_u);
} else {
} else if (auto count = arr->ConstantCount()) {
if (sem->Index()->ConstantValue()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
return nullptr;
}
max = b.Expr(u32(arr->Count() - 1u));
max = b.Expr(u32(count.value() - 1u));
} else {
// Note: Don't be tempted to use the array override variable as an expression
// here, the name might be shadowed!
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return nullptr;
}
return b.Call("min", idx(), max);
},
[&](Default) {

View File

@ -1316,5 +1316,23 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(RobustnessTest, WorkgroupOverrideCount) {
auto* src = R"(
override N = 123;
var<workgroup> w : array<f32, N>;
fn f() {
var b : f32 = w[1i];
}
)";
auto* expect = R"(error: array size is an override-expression, when expected a constant-expression.
Was the SubstituteOverride transform run?)";
auto got = Run<Robustness>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@ -35,6 +35,8 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SpirvAtomic::Stub);
namespace tint::transform {
using namespace tint::number_suffixes; // NOLINT
/// Private implementation of transform
struct SpirvAtomic::State {
private:
@ -189,10 +191,19 @@ struct SpirvAtomic::State {
[&](const sem::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::U32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
[&](const sem::Struct* str) { return b.ty.type_name(Fork(str->Declaration()).name); },
[&](const sem::Array* arr) {
return arr->IsRuntimeSized()
? b.ty.array(AtomicTypeFor(arr->ElemType()))
: b.ty.array(AtomicTypeFor(arr->ElemType()), u32(arr->Count()));
[&](const sem::Array* arr) -> const ast::Type* {
if (arr->IsRuntimeSized()) {
return b.ty.array(AtomicTypeFor(arr->ElemType()));
}
auto count = arr->ConstantCount();
if (!count) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"the SpirvAtomic transform does not currently support array counts that "
"use override values");
count = 1;
}
return b.ty.array(AtomicTypeFor(arr->ElemType()), u32(count.value()));
},
[&](const sem::Pointer* ptr) {
return b.ty.pointer(AtomicTypeFor(ptr->StoreType()), ptr->StorageClass(),

View File

@ -423,7 +423,17 @@ struct Std140::State {
if (!arr->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
}
return b.create<ast::Array>(std140, b.Expr(u32(arr->Count())),
auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
return b.create<ast::Array>(std140, b.Expr(u32(count.value())),
std::move(attrs));
}
return nullptr;
@ -613,7 +623,17 @@ struct Std140::State {
ty, //
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
[&](const sem::Array* arr) {
return "arr" + std::to_string(arr->Count()) + "_" + ConvertSuffix(arr->ElemType());
auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
},
[&](const sem::Matrix* mat) {
return "mat" + std::to_string(mat->columns()) + "x" + std::to_string(mat->rows()) +
@ -710,9 +730,19 @@ struct Std140::State {
auto* i = b.Var("i", b.ty.u32());
auto* dst_el = b.IndexAccessor(var, i);
auto* src_el = Convert(arr->ElemType(), b.IndexAccessor(param, i));
auto count = arr->ConstantCount();
if (!count) {
// Non-constant counts should not be possible:
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
stmts.Push(b.Decl(var));
stmts.Push(b.For(b.Decl(i), //
b.LessThan(i, u32(arr->Count())), //
b.LessThan(i, u32(count.value())), //
b.Assign(i, b.Add(i, 1_a)), //
b.Block(b.Assign(dst_el, src_el))));
stmts.Push(b.Return(var));

View File

@ -24,6 +24,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/sampler.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
@ -112,9 +113,16 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type*
}
if (a->IsRuntimeSized()) {
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
} else {
return ctx.dst->ty.array(el, u32(a->Count()), std::move(attrs));
}
if (auto* override = std::get_if<sem::OverrideArrayCount>(&a->Count())) {
auto* count = ctx.Clone(override->variable->Declaration());
return ctx.dst->ty.array(el, count, std::move(attrs));
}
if (auto count = a->ConstantCount()) {
return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
}
TINT_ICE(Transform, ctx.dst->Diagnostics()) << sem::Array::kErrExpectedConstantCount;
return ctx.dst->ty.array(el, u32(1), std::move(attrs));
}
if (auto* s = ty->As<sem::Struct>()) {
return ctx.dst->create<ast::TypeName>(ctx.Clone(s->Declaration()->name));

View File

@ -65,7 +65,8 @@ TEST_F(CreateASTTypeForTest, Vector) {
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 32u, 32u);
return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 32u,
32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());
@ -78,7 +79,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
auto* arr = create([](ProgramBuilder& b) {
return b.create<sem::Array>(b.create<sem::F32>(), 2u, 4u, 4u, 64u, 32u);
return b.create<sem::Array>(b.create<sem::F32>(), sem::ConstantArrayCount{2u}, 4u, 4u, 64u,
32u);
});
ASSERT_TRUE(arr->Is<ast::Array>());
ASSERT_TRUE(arr->As<ast::Array>()->type->Is<ast::F32>());

View File

@ -307,7 +307,13 @@ struct ZeroInitWorkgroupMemory::State {
// `num_values * arr->Count()`
// The index for this array is:
// `(idx % modulo) / division`
auto modulo = num_values * arr->Count();
auto count = arr->ConstantCount();
if (!count) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return Expression{};
}
auto modulo = num_values * count.value();
auto division = num_values;
auto a = get_expr(modulo);
auto array_indices = a.array_indices;

View File

@ -2240,7 +2240,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
ScopedParen sp(out);
for (size_t i = 0; i < a->Count(); i++) {
auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@ -2356,16 +2362,23 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
}
EmitZeroValue(out, member->Type());
}
} else if (auto* array = type->As<sem::Array>()) {
} else if (auto* arr = type->As<sem::Array>()) {
if (!EmitType(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, "")) {
return false;
}
ScopedParen sp(out);
for (uint32_t i = 0; i < array->Count(); i++) {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (uint32_t i = 0; i < count; i++) {
if (i != 0) {
out << ", ";
}
EmitZeroValue(out, array->ElemType());
EmitZeroValue(out, arr->ElemType());
}
} else {
diagnostics_.add_error(diag::System::Writer, "Invalid type for zero emission: " +
@ -2697,7 +2710,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
const sem::Type* base_type = ary;
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<sem::Array>()) {
sizes.push_back(arr->Count());
if (arr->IsRuntimeSized()) {
sizes.push_back(0);
} else {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
sizes.push_back(count.value());
}
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {

View File

@ -3193,7 +3193,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
out << "{";
TINT_DEFER(out << "}");
for (size_t i = 0; i < a->Count(); i++) {
auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@ -3732,11 +3738,18 @@ bool GeneratorImpl::EmitType(std::ostream& out,
while (auto* arr = base_type->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
TINT_ICE(Writer, diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should have "
<< "runtime arrays may only exist in storage buffers, which should have "
"been transformed into a ByteAddressBuffer";
return false;
}
sizes.push_back(arr->Count());
const auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
sizes.push_back(count.value());
base_type = arr->ElemType();
}
if (!EmitType(out, base_type, storage_class, access, "")) {

View File

@ -1690,7 +1690,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
return true;
}
for (size_t i = 0; i < a->Count(); i++) {
auto count = a->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return false;
}
for (size_t i = 0; i < count; i++) {
if (i > 0) {
out << ", ";
}
@ -2481,7 +2487,20 @@ bool GeneratorImpl::EmitType(std::ostream& out,
if (!EmitType(out, arr->ElemType(), "")) {
return false;
}
out << ", " << (arr->IsRuntimeSized() ? 1u : arr->Count()) << ">";
out << ", ";
if (arr->IsRuntimeSized()) {
out << "1";
} else {
auto count = arr->ConstantCount();
if (!count) {
diagnostics_.add_error(diag::System::Writer,
sem::Array::kErrExpectedConstantCount);
return false;
}
out << count.value();
}
out << ">";
return true;
},
[&](const sem::Bool*) {
@ -3133,8 +3152,14 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
<< "arrays with explicit strides should not exist past the SPIR-V reader";
return SizeAndAlign{};
}
auto num_els = std::max<uint32_t>(arr->Count(), 1);
return SizeAndAlign{arr->Stride() * num_els, arr->Align()};
if (arr->IsRuntimeSized()) {
return SizeAndAlign{arr->Stride(), arr->Align()};
}
if (auto count = arr->ConstantCount()) {
return SizeAndAlign{arr->Stride() * count.value(), arr->Align()};
}
diagnostics_.add_error(diag::System::Writer, sem::Array::kErrExpectedConstantCount);
return SizeAndAlign{};
},
[&](const sem::Struct* str) {

View File

@ -1676,11 +1676,18 @@ uint32_t Builder::GenerateConstantIfNeeded(const sem::Constant* constant) {
},
[&](const sem::Vector* v) { return composite(v->Width()); },
[&](const sem::Matrix* m) { return composite(m->columns()); },
[&](const sem::Array* a) { return composite(a->Count()); },
[&](const sem::Array* a) {
auto count = a->ConstantCount();
if (!count) {
error_ = sem::Array::kErrExpectedConstantCount;
return static_cast<uint32_t>(0);
}
return composite(count.value());
},
[&](const sem::Struct* s) { return composite(s->Members().size()); },
[&](Default) {
error_ = "unhandled constant type: " + builder_.FriendlyName(ty);
return false;
return 0;
});
}
@ -3852,17 +3859,23 @@ bool Builder::GenerateTextureType(const sem::Texture* texture, const Operand& re
return true;
}
bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
auto elem_type = GenerateTypeIfNeeded(ary->ElemType());
bool Builder::GenerateArrayType(const sem::Array* arr, const Operand& result) {
auto elem_type = GenerateTypeIfNeeded(arr->ElemType());
if (elem_type == 0) {
return false;
}
auto result_id = std::get<uint32_t>(result);
if (ary->IsRuntimeSized()) {
if (arr->IsRuntimeSized()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand(elem_type)});
} else {
auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->Count()));
auto count = arr->ConstantCount();
if (!count) {
error_ = sem::Array::kErrExpectedConstantCount;
return static_cast<uint32_t>(0);
}
auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(count.value()));
if (len_id == 0) {
return false;
}
@ -3871,7 +3884,7 @@ bool Builder::GenerateArrayType(const sem::Array* ary, const Operand& result) {
}
push_annot(spv::Op::OpDecorate,
{Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(ary->Stride())});
{Operand(result_id), U32Operand(SpvDecorationArrayStride), Operand(arr->Stride())});
return true;
}

View File

@ -0,0 +1,5 @@
// flags: --transform substitute_override
override size = 2;
var<workgroup> a : array<f32, size>;

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
groupshared float a[2];

View File

@ -0,0 +1,6 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
groupshared float a[2];

View File

@ -0,0 +1,7 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
shared float a[2];

View File

@ -0,0 +1,3 @@
#include <metal_stdlib>
using namespace metal;

View File

@ -0,0 +1,24 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 11
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %a "a"
OpName %unused_entry_point "unused_entry_point"
OpDecorate %_arr_float_uint_2 ArrayStride 4
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_float_uint_2 = OpTypeArray %float %uint_2
%_ptr_Workgroup__arr_float_uint_2 = OpTypePointer Workgroup %_arr_float_uint_2
%a = OpVariable %_ptr_Workgroup__arr_float_uint_2 Workgroup
%void = OpTypeVoid
%7 = OpTypeFunction %void
%unused_entry_point = OpFunction %void None %7
%10 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,3 @@
const size = 2;
var<workgroup> a : array<f32, size>;