Resolver: Validate that type sizes fit in uint32_t

Bug: chromium:1249708
Bug: tint:1177
Change-Id: I31c52f160e4952475e977453206ab4224fd20df7
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64320
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-09-15 17:37:00 +00:00 committed by Tint LUCI CQ
parent 6556ba0e94
commit d1d99bc7de
3 changed files with 121 additions and 19 deletions

View File

@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <iomanip> #include <iomanip>
#include <limits>
#include <utility> #include <utility>
#include "src/ast/alias.h" #include "src/ast/alias.h"
@ -3870,9 +3871,9 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
} }
// Calculate implicit stride // Calculate implicit stride
auto implicit_stride = utils::RoundUp(el_align, el_size); uint64_t implicit_stride = utils::RoundUp<uint64_t>(el_align, el_size);
auto stride = explicit_stride ? explicit_stride : implicit_stride; uint64_t stride = explicit_stride ? explicit_stride : implicit_stride;
// Evaluate the constant array size expression. // Evaluate the constant array size expression.
// sem::Array uses a size of 0 for a runtime-sized array. // sem::Array uses a size of 0 for a runtime-sized array.
@ -3933,9 +3934,24 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
count = count_val.Elements()[0].u32; count = count_val.Elements()[0].u32;
} }
auto size = std::max<uint32_t>(count, 1) * stride; auto size = std::max<uint64_t>(count, 1) * stride;
auto* out = builder_->create<sem::Array>(elem_type, count, el_align, size, if (size > std::numeric_limits<uint32_t>::max()) {
stride, implicit_stride); std::stringstream msg;
msg << "array size in bytes must not exceed 0x" << std::hex
<< std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
<< size;
AddError(msg.str(), arr->source());
return nullptr;
}
if (stride > std::numeric_limits<uint32_t>::max() ||
implicit_stride > std::numeric_limits<uint32_t>::max()) {
TINT_ICE(Resolver, diagnostics_)
<< "calculated array stride exceeds uint32";
return nullptr;
}
auto* out = builder_->create<sem::Array>(
elem_type, count, el_align, static_cast<uint32_t>(size),
static_cast<uint32_t>(stride), static_cast<uint32_t>(implicit_stride));
if (!ValidateArray(out, source)) { if (!ValidateArray(out, source)) {
return nullptr; return nullptr;
@ -4154,8 +4170,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
// Validation of storage-class rules requires analysing the actual variable // Validation of storage-class rules requires analysing the actual variable
// usage of the structure, and so is performed as part of the variable // usage of the structure, and so is performed as part of the variable
// validation. // validation.
uint32_t struct_size = 0; uint64_t struct_size = 0;
uint32_t struct_align = 1; uint64_t struct_align = 1;
std::unordered_map<Symbol, ast::StructMember*> member_map; std::unordered_map<Symbol, ast::StructMember*> member_map;
for (auto* member : str->members()) { for (auto* member : str->members()) {
@ -4183,9 +4199,9 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
return nullptr; return nullptr;
} }
uint32_t offset = struct_size; uint64_t offset = struct_size;
uint32_t align = type->Align(); uint64_t align = type->Align();
uint32_t size = type->Size(); uint64_t size = type->Size();
if (!ValidateNoDuplicateDecorations(member->decorations())) { if (!ValidateNoDuplicateDecorations(member->decorations())) {
return nullptr; return nullptr;
@ -4234,6 +4250,14 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
} }
offset = utils::RoundUp(align, offset); offset = utils::RoundUp(align, offset);
if (offset > std::numeric_limits<uint32_t>::max()) {
std::stringstream msg;
msg << "struct member has byte offset 0x" << std::hex << offset
<< ", but must not exceed 0x" << std::hex
<< std::numeric_limits<uint32_t>::max();
AddError(msg.str(), member->source());
return nullptr;
}
auto* sem_member = builder_->create<sem::StructMember>( auto* sem_member = builder_->create<sem::StructMember>(
member, member->symbol(), const_cast<sem::Type*>(type), member, member->symbol(), const_cast<sem::Type*>(type),
@ -4245,12 +4269,27 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
struct_align = std::max(struct_align, align); struct_align = std::max(struct_align, align);
} }
auto size_no_padding = struct_size; uint64_t size_no_padding = struct_size;
struct_size = utils::RoundUp(struct_align, struct_size); struct_size = utils::RoundUp(struct_align, struct_size);
auto* out = if (struct_size > std::numeric_limits<uint32_t>::max()) {
builder_->create<sem::Struct>(str, str->name(), sem_members, struct_align, std::stringstream msg;
struct_size, size_no_padding); msg << "struct size in bytes must not exceed 0x" << std::hex
<< std::numeric_limits<uint32_t>::max() << ", but is 0x" << std::hex
<< struct_size;
AddError(msg.str(), str->source());
return nullptr;
}
if (struct_align > std::numeric_limits<uint32_t>::max()) {
TINT_ICE(Resolver, diagnostics_)
<< "calculated struct stride exceeds uint32";
return nullptr;
}
auto* out = builder_->create<sem::Struct>(
str, str->name(), sem_members, static_cast<uint32_t>(struct_align),
static_cast<uint32_t>(struct_size),
static_cast<uint32_t>(size_no_padding));
for (size_t i = 0; i < sem_members.size(); i++) { for (size_t i = 0; i < sem_members.size(); i++) {
auto* mem_type = sem_members[i]->Type(); auto* mem_type = sem_members[i]->Type();

View File

@ -325,6 +325,26 @@ TEST_F(ResolverTypeValidationTest, ArraySize_IVecConstant) {
EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar"); EXPECT_EQ(r()->error(), "12:34 error: array size must be integer scalar");
} }
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ImplicitStride) {
// var<private> a : array<f32, 0x40000000>;
Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x40000000),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size in bytes must not exceed 0xffffffff, but "
"is 0x100000000");
}
TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
// var<private> a : [[stride(8)]] array<f32, 0x20000000>;
Global("a", ty.array(Source{{12, 34}}, ty.f32(), 0x20000000, 8),
ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: array size in bytes must not exceed 0xffffffff, but "
"is 0x100000000");
}
TEST_F(ResolverTypeValidationTest, ArraySize_OverridableConstant) { TEST_F(ResolverTypeValidationTest, ArraySize_OverridableConstant) {
// [[override]] let size = 10; // [[override]] let size = 10;
// var<private> a : array<f32, size>; // var<private> a : array<f32, size>;
@ -396,8 +416,49 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayInFunction_Fail) {
"a struct"); "a struct");
} }
TEST_F(ResolverTypeValidationTest, Struct_TooBig) {
// struct Foo {
// a: array<f32, 0x20000000>;
// b: array<f32, 0x20000000>;
// };
Structure(Source{{12, 34}}, "Foo",
{
Member("a", ty.array<f32, 0x20000000>()),
Member("b", ty.array<f32, 0x20000000>()),
});
WrapInFunction();
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: struct size in bytes must not exceed 0xffffffff, but "
"is 0x100000000");
}
TEST_F(ResolverTypeValidationTest, Struct_MemberOffset_TooBig) {
// struct Foo {
// a: array<f32, 0x3fffffff>;
// b: f32;
// c: f32;
// };
Structure("Foo", {
Member("a", ty.array<f32, 0x3fffffff>()),
Member("b", ty.f32()),
Member(Source{{12, 34}}, "c", ty.f32()),
});
WrapInFunction();
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: struct member has byte offset 0x100000000, but must "
"not exceed 0xffffffff");
}
TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) { TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLast_Pass) {
// [[Block]] // [[block]]
// struct Foo { // struct Foo {
// vf: f32; // vf: f32;
// rt: array<f32>; // rt: array<f32>;
@ -435,7 +496,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayIsLastNoBlock_Fail) {
} }
TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) { TEST_F(ResolverTypeValidationTest, RuntimeArrayIsNotLast_Fail) {
// [[Block]] // [[block]]
// struct Foo { // struct Foo {
// rt: array<f32>; // rt: array<f32>;
// vf: f32; // vf: f32;
@ -504,7 +565,7 @@ TEST_F(ResolverTypeValidationTest, RuntimeArrayAsParameter_Fail) {
} }
TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) { TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) {
// [[Block]] // [[block]]
// type RTArr = array<u32>; // type RTArr = array<u32>;
// struct s { // struct s {
// b: RTArr; // b: RTArr;
@ -528,7 +589,7 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsNotLast_Fail) {
} }
TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) { TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
// [[Block]] // [[block]]
// type RTArr = array<u32>; // type RTArr = array<u32>;
// struct s { // struct s {
// a: u32; // a: u32;

View File

@ -170,7 +170,9 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(RobustnessTest, LargeArrays_Idx) { // TODO(crbug.com/tint/1177) - Validation currently forbids arrays larger than
// 0xffffffff. If WGSL supports 64-bit indexing, re-enable this test.
TEST_F(RobustnessTest, DISABLED_LargeArrays_Idx) {
auto* src = R"( auto* src = R"(
[[block]] [[block]]
struct S { struct S {