diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 1b26b607f2..0b8c48d775 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "src/ast/alias.h" @@ -3870,9 +3871,9 @@ sem::Array* Resolver::Array(const ast::Array* arr) { } // Calculate implicit stride - auto implicit_stride = utils::RoundUp(el_align, el_size); + uint64_t implicit_stride = utils::RoundUp(el_align, el_size); - auto stride = explicit_stride ? explicit_stride : implicit_stride; + uint64_t stride = explicit_stride ? explicit_stride : implicit_stride; // Evaluate the constant array size expression. // sem::Array uses a size of 0 for a runtime-sized array. @@ -3933,9 +3934,24 @@ sem::Array* Resolver::Array(const ast::Array* arr) { count = count_val.Elements()[0].u32; } - auto size = std::max(count, 1) * stride; - auto* out = builder_->create(elem_type, count, el_align, size, - stride, implicit_stride); + auto size = std::max(count, 1) * stride; + if (size > std::numeric_limits::max()) { + std::stringstream msg; + msg << "array size in bytes must not exceed 0x" << std::hex + << std::numeric_limits::max() << ", but is 0x" << std::hex + << size; + AddError(msg.str(), arr->source()); + return nullptr; + } + if (stride > std::numeric_limits::max() || + implicit_stride > std::numeric_limits::max()) { + TINT_ICE(Resolver, diagnostics_) + << "calculated array stride exceeds uint32"; + return nullptr; + } + auto* out = builder_->create( + elem_type, count, el_align, static_cast(size), + static_cast(stride), static_cast(implicit_stride)); if (!ValidateArray(out, source)) { return nullptr; @@ -4154,8 +4170,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { // Validation of storage-class rules requires analysing the actual variable // usage of the structure, and so is performed as part of the variable // validation. - uint32_t struct_size = 0; - uint32_t struct_align = 1; + uint64_t struct_size = 0; + uint64_t struct_align = 1; std::unordered_map member_map; for (auto* member : str->members()) { @@ -4183,9 +4199,9 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { return nullptr; } - uint32_t offset = struct_size; - uint32_t align = type->Align(); - uint32_t size = type->Size(); + uint64_t offset = struct_size; + uint64_t align = type->Align(); + uint64_t size = type->Size(); if (!ValidateNoDuplicateDecorations(member->decorations())) { return nullptr; @@ -4234,6 +4250,14 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { } offset = utils::RoundUp(align, offset); + if (offset > std::numeric_limits::max()) { + std::stringstream msg; + msg << "struct member has byte offset 0x" << std::hex << offset + << ", but must not exceed 0x" << std::hex + << std::numeric_limits::max(); + AddError(msg.str(), member->source()); + return nullptr; + } auto* sem_member = builder_->create( member, member->symbol(), const_cast(type), @@ -4245,12 +4269,27 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { struct_align = std::max(struct_align, align); } - auto size_no_padding = struct_size; + uint64_t size_no_padding = struct_size; struct_size = utils::RoundUp(struct_align, struct_size); - auto* out = - builder_->create(str, str->name(), sem_members, struct_align, - struct_size, size_no_padding); + if (struct_size > std::numeric_limits::max()) { + std::stringstream msg; + msg << "struct size in bytes must not exceed 0x" << std::hex + << std::numeric_limits::max() << ", but is 0x" << std::hex + << struct_size; + AddError(msg.str(), str->source()); + return nullptr; + } + if (struct_align > std::numeric_limits::max()) { + TINT_ICE(Resolver, diagnostics_) + << "calculated struct stride exceeds uint32"; + return nullptr; + } + + auto* out = builder_->create( + str, str->name(), sem_members, static_cast(struct_align), + static_cast(struct_size), + static_cast(size_no_padding)); for (size_t i = 0; i < sem_members.size(); i++) { auto* mem_type = sem_members[i]->Type(); diff --git a/src/resolver/type_validation_test.cc b/src/resolver/type_validation_test.cc index 549b814713..354978f772 100644 --- a/src/resolver/type_validation_test.cc +++ b/src/resolver/type_validation_test.cc @@ -325,6 +325,26 @@ TEST_F(ResolverTypeValidationTest, ArraySize_IVecConstant) { EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar"); } +TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) { + // var a : array; + Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000), + ast::StorageClass::kPrivate); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: array size in bytes must not exceed 0xffffffff, but " + "is 0x100000000"); +} + +TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) { + // var a : [[stride(8)]] array; + Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000, 8), + ast::StorageClass::kPrivate); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: array size in bytes must not exceed 0xffffffff, but " + "is 0x100000000"); +} + TEST_F(ResolverTypeValidationTest, ArraySize_OverridableConstant) { // [[override]] let size = 10; // var a : array; @@ -396,8 +416,49 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) { "a struct"); } +TEST_F(ResolverTypeValidationTest, Struct_TooBig) { + // struct Foo { + // a: array; + // b: array; + // }; + + Structure(Source{{12, 34}}, "Foo", + { + Member("a", ty.array()), + Member("b", ty.array()), + }); + + WrapInFunction(); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: struct size in bytes must not exceed 0xffffffff, but " + "is 0x100000000"); +} + +TEST_F(ResolverTypeValidationTest, Struct_MemberOffset_TooBig) { + // struct Foo { + // a: array; + // b: f32; + // c: f32; + // }; + + Structure("Foo", { + Member("a", ty.array()), + Member("b", ty.f32()), + Member(Source{{12, 34}}, "c", ty.f32()), + }); + + WrapInFunction(); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: struct member has byte offset 0x100000000, but must " + "not exceed 0xffffffff"); +} + TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) { - // [[Block]] + // [[block]] // struct Foo { // vf: f32; // rt: array; @@ -435,7 +496,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLastNoBlock_Fail) { } TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) { - // [[Block]] + // [[block]] // struct Foo { // rt: array; // vf: f32; @@ -504,7 +565,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayAsParameter_Fail) { } TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) { - // [[Block]] + // [[block]] // type RTArr = array; // struct s { // b: RTArr; @@ -528,7 +589,7 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) { } TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) { - // [[Block]] + // [[block]] // type RTArr = array; // struct s { // a: u32; diff --git a/src/transform/robustness_test.cc b/src/transform/robustness_test.cc index 2e5b96b2e7..3f9078746a 100644 --- a/src/transform/robustness_test.cc +++ b/src/transform/robustness_test.cc @@ -170,7 +170,9 @@ fn f() { EXPECT_EQ(expect, str(got)); } -TEST_F(RobustnessTest, LargeArrays_Idx) { +// TODO(crbug.com/tint/1177) - Validation currently forbids arrays larger than +// 0xffffffff. If WGSL supports 64-bit indexing, re-enable this test. +TEST_F(RobustnessTest, DISABLED_LargeArrays_Idx) { auto* src = R"( [[block]] struct S {