diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 70d1b4c132..0c021a2353 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -24,6 +24,9 @@ #include "gmock/gmock.h" namespace tint { +namespace resolver { + +namespace DecorationTests { namespace { enum class DecorationKind { @@ -45,12 +48,11 @@ struct TestParams { DecorationKind kind; bool should_pass; }; -class TestWithParams : public resolver::TestHelper, - public testing::TestWithParam {}; +struct TestWithParams : ResolverTestWithParam {}; -ast::Decoration* createDecoration(const Source& source, - ProgramBuilder& builder, - DecorationKind kind) { +static ast::Decoration* createDecoration(const Source& source, + ProgramBuilder& builder, + DecorationKind kind) { switch (kind) { case DecorationKind::kAccess: return builder.create( @@ -87,7 +89,7 @@ ast::Decoration* createDecoration(const Source& source, using FunctionReturnTypeDecorationTest = TestWithParams; TEST_P(FunctionReturnTypeDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{create(Expr(1.f))}, @@ -121,9 +123,8 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kWorkgroup, false})); using ArrayDecorationTest = TestWithParams; - TEST_P(ArrayDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); ast::StructMemberList members{Member( "a", create(ty.f32(), 0, @@ -163,7 +164,7 @@ INSTANTIATE_TEST_SUITE_P( using StructDecorationTest = TestWithParams; TEST_P(StructDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); auto* s = create(ast::StructMemberList{}, ast::DecorationList{createDecoration( @@ -200,7 +201,7 @@ INSTANTIATE_TEST_SUITE_P( using StructMemberDecorationTest = TestWithParams; TEST_P(StructMemberDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); ast::StructMemberList members{ Member("a", ty.i32(), @@ -239,7 +240,7 @@ INSTANTIATE_TEST_SUITE_P( using VariableDecorationTest = TestWithParams; TEST_P(VariableDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); Global("a", ty.f32(), ast::StorageClass::kInput, nullptr, ast::DecorationList{ @@ -274,7 +275,7 @@ INSTANTIATE_TEST_SUITE_P( using FunctionDecorationTest = TestWithParams; TEST_P(FunctionDecorationTest, IsValid) { - auto params = GetParam(); + auto& params = GetParam(); Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{}, ast::DecorationList{ @@ -307,4 +308,121 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kWorkgroup, true})); } // namespace +} // namespace DecorationTests + +namespace ArrayStrideTests { +namespace { + +struct Params { + create_type_func_ptr create_el_type; + uint32_t stride; + bool should_pass; +}; + +struct TestWithParams : ResolverTestWithParam {}; + +using ArrayStrideTest = TestWithParams; +TEST_P(ArrayStrideTest, All) { + auto& params = GetParam(); + auto* el_ty = params.create_el_type(ty); + + std::stringstream ss; + ss << "el_ty: " << el_ty->FriendlyName(Symbols()) + << ", stride: " << params.stride + << ", should_pass: " << params.should_pass; + SCOPED_TRACE(ss.str()); + + auto* arr = + create(el_ty, 4, + ast::DecorationList{ + create(params.stride), + }); + + Global(Source{{12, 34}}, "myarray", arr, ast::StorageClass::kInput); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: arrays decorated with the stride attribute must " + "have a stride that is at least the size of the element type, " + "and be a multiple of the element type's alignment value."); + } +} + +// Helpers and typedefs +using i32 = ProgramBuilder::i32; +using u32 = ProgramBuilder::u32; +using f32 = ProgramBuilder::f32; + +struct SizeAndAlignment { + uint32_t size; + uint32_t align; +}; +constexpr SizeAndAlignment default_u32 = {4, 4}; +constexpr SizeAndAlignment default_i32 = {4, 4}; +constexpr SizeAndAlignment default_f32 = {4, 4}; +constexpr SizeAndAlignment default_vec2 = {8, 8}; +constexpr SizeAndAlignment default_vec3 = {12, 16}; +constexpr SizeAndAlignment default_vec4 = {16, 16}; +constexpr SizeAndAlignment default_mat2x2 = {16, 8}; +constexpr SizeAndAlignment default_mat3x3 = {48, 16}; +constexpr SizeAndAlignment default_mat4x4 = {64, 16}; + +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + ArrayStrideTest, + testing::Values( + // Succeed because stride >= element size (while being multiple of + // element alignment) + Params{ty_u32, default_u32.size, true}, + Params{ty_i32, default_i32.size, true}, + Params{ty_f32, default_f32.size, true}, + Params{ty_vec2, default_vec2.size, true}, + // vec3's default size is not a multiple of its alignment + // Params{ty_vec3, default_vec3.size, true}, + Params{ty_vec4, default_vec4.size, true}, + Params{ty_mat2x2, default_mat2x2.size, true}, + Params{ty_mat3x3, default_mat3x3.size, true}, + Params{ty_mat4x4, default_mat4x4.size, true}, + + // Fail because stride is < element size + Params{ty_u32, default_u32.size - 1, false}, + Params{ty_i32, default_i32.size - 1, false}, + Params{ty_f32, default_f32.size - 1, false}, + Params{ty_vec2, default_vec2.size - 1, false}, + Params{ty_vec3, default_vec3.size - 1, false}, + Params{ty_vec4, default_vec4.size - 1, false}, + Params{ty_mat2x2, default_mat2x2.size - 1, false}, + Params{ty_mat3x3, default_mat3x3.size - 1, false}, + Params{ty_mat4x4, default_mat4x4.size - 1, false}, + + // Succeed because stride equals multiple of element alignment + Params{ty_u32, default_u32.align * 7, true}, + Params{ty_i32, default_i32.align * 7, true}, + Params{ty_f32, default_f32.align * 7, true}, + Params{ty_vec2, default_vec2.align * 7, true}, + Params{ty_vec3, default_vec3.align * 7, true}, + Params{ty_vec4, default_vec4.align * 7, true}, + Params{ty_mat2x2, default_mat2x2.align * 7, true}, + Params{ty_mat3x3, default_mat3x3.align * 7, true}, + Params{ty_mat4x4, default_mat4x4.align * 7, true}, + + // Fail because stride is not multiple of element alignment + Params{ty_u32, (default_u32.align - 1) * 7, false}, + Params{ty_i32, (default_i32.align - 1) * 7, false}, + Params{ty_f32, (default_f32.align - 1) * 7, false}, + Params{ty_vec2, (default_vec2.align - 1) * 7, false}, + Params{ty_vec3, (default_vec3.align - 1) * 7, false}, + Params{ty_vec4, (default_vec4.align - 1) * 7, false}, + Params{ty_mat2x2, (default_mat2x2.align - 1) * 7, false}, + Params{ty_mat3x3, (default_mat3x3.align - 1) * 7, false}, + Params{ty_mat4x4, (default_mat4x4.align - 1) * 7, false} + + )); + +} // namespace +} // namespace ArrayStrideTests +} // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index dfcabea03b..5bbc82093d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1824,13 +1824,13 @@ const semantic::Array* Resolver::Array(type::Array* arr, const Source& source) { return nullptr; } - auto create_semantic = [&](uint32_t stride) -> semantic::Array* { - uint32_t el_align = 0; - uint32_t el_size = 0; - if (!DefaultAlignAndSize(arr->type(), el_align, el_size, source)) { - return nullptr; - } + uint32_t el_align = 0; + uint32_t el_size = 0; + if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) { + return nullptr; + } + auto create_semantic = [&](uint32_t stride) -> semantic::Array* { auto align = el_align; // WebGPU requires runtime arrays have at least one element, but the AST // records an element count of 0 for it. @@ -1843,18 +1843,30 @@ const semantic::Array* Resolver::Array(type::Array* arr, const Source& source) { // Look for explicit stride via [[stride(n)]] decoration for (auto* deco : arr->decorations()) { if (auto* stride = deco->As()) { - return create_semantic(stride->stride()); + auto explicit_stride = stride->stride(); + bool is_valid_stride = (explicit_stride >= el_size) && + (explicit_stride >= el_align) && + (explicit_stride % el_align == 0); + if (!is_valid_stride) { + // https://gpuweb.github.io/gpuweb/wgsl/#array-layout-rules + // Arrays decorated with the stride attribute must have a stride that is + // at least the size of the element type, and be a multiple of the + // element type's alignment value. + diagnostics_.add_error( + "arrays decorated with the stride attribute must have a stride " + "that is at least the size of the element type, and be a multiple " + "of the element type's alignment value.", + source); + return nullptr; + } + + return create_semantic(explicit_stride); } } // Calculate implicit stride - uint32_t el_align = 0; - uint32_t el_size = 0; - if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) { - return nullptr; - } - - return create_semantic(utils::RoundUp(el_align, el_size)); + auto implicit_stride = utils::RoundUp(el_align, el_size); + return create_semantic(implicit_stride); } bool Resolver::ValidateStructure(const type::Struct* st) { diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index cd9dd7ad3a..a3dfce1ad1 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -122,6 +122,17 @@ inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) { using create_type_func_ptr = type::Type* (*)(const ProgramBuilder::TypesBuilder& ty); +template +type::Type* ty_vec2(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec2(); +} + +template +type::Type* ty_vec2(const ProgramBuilder::TypesBuilder& ty) { + auto* type = create_type(ty); + return ty.vec2(type); +} + template type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { return ty.vec3(); @@ -133,6 +144,28 @@ type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { return ty.vec3(type); } +template +type::Type* ty_vec4(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec4(); +} + +template +type::Type* ty_vec4(const ProgramBuilder::TypesBuilder& ty) { + auto* type = create_type(ty); + return ty.vec4(type); +} + +template +type::Type* ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x2(); +} + +template +type::Type* ty_mat2x2(const ProgramBuilder::TypesBuilder& ty) { + auto* type = create_type(ty); + return ty.mat2x2(type); +} + template type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { return ty.mat3x3(); @@ -144,6 +177,17 @@ type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { return ty.mat3x3(type); } +template +type::Type* ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat4x4(); +} + +template +type::Type* ty_mat4x4(const ProgramBuilder::TypesBuilder& ty) { + auto* type = create_type(ty); + return ty.mat4x4(type); +} + template type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) { auto* type = create_type(ty);